예제 #1
0
    def load_predictor(self):
        """
        DESCRIPTION:
        ------------

        Prepare the prediction pipeline if the predict step is enabled

        The input transform manager, the output transformer and the prediction
        dataset are built. The Predictor itself is not instantiated for now
        (its construction is commented out at the end of the method), so
        self.predictor is not assigned here.
        If the predict step is disabled, only an info notification is issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Predict step disabled: notify and leave early
        if not self.config.data.enabled.predict:
            Notification(
                DEEP_NOTIF_INFO,
                DEEP_MSG_DATA_DISABLED % DEEP_DATASET_PREDICTION.name)
            return

        self.loading_message("Predictor")

        # Locate the prediction entry among the configured datasets
        index = self.get_dataset_index(DEEP_DATASET_PREDICTION)
        dataset_config = self.config.data.datasets[index]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % dataset_config.name)

        predict_transforms = self.config.transform.predict

        # Input Transform Manager (every transform entry except "outputs")
        input_kwargs = predict_transforms.get(ignore="outputs")
        transform_manager = TransformManager(**input_kwargs)

        # Output Transform Manager (only the "outputs" entry)
        output_transform_manager = OutputTransformer(
            transform_files=predict_transforms.get("outputs"))

        # Prediction dataset; batch_size is left out of the Dataset arguments
        dataset_kwargs = dataset_config.get(ignore=["batch_size"])
        dataset = Dataset(**dataset_kwargs,
                          transform_manager=transform_manager)

        # NOTE: the Predictor creation is currently disabled, therefore
        # output_transform_manager and dataset are built but not consumed yet
        # self.predictor = Predictor(
        #     **self.config.data.dataloader.get(),
        #     batch_size=dataset_config.batch_size,
        #     name="Predictor",
        #     model=self.model,
        #     dataset=dataset,
        #     transform_manager=output_transform_manager
        # )
예제 #2
0
    def __init__(self,
                 model: Module,
                 dataset: Dataset,
                 batch_size: int = 4,
                 num_workers: int = 4):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize a GenericInferer instance

        The arguments are stored on the instance, a non-shuffling DataLoader
        is created over the dataset, and the number of mini-batches is
        computed from the batch size and the dataset length.

        PARAMETERS:
        -----------

        :param model (torch.nn.Module): The model to infer
        :param dataset (Dataset): A dataset
        :param batch_size (int): The number of instances per batch
        :param num_workers (int): The number of processes / threads used for data loading

        RETURN:
        -------

        :return: None
        """

        self.model = model
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.dataset = dataset

        # shuffle=False: the loader itself never reorders the instances
        self.dataloader = DataLoader(dataset=dataset,
                                     batch_size=batch_size,
                                     shuffle=False,
                                     num_workers=num_workers)

        # Use the built-in len() rather than calling the dunder directly
        self.num_minibatches = self.compute_num_minibatches(
            batch_size=batch_size, length_dataset=len(dataset))
예제 #3
0
    def load_predictor(self):
        """
        DESCRIPTION:
        ------------

        Load the predictor if the predict step is enabled

        Builds the input transform manager, the output transformer and the
        prediction dataset, then stores a Predictor in self.predictor.
        If the predict step is disabled, only an info notification is issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Predict step disabled: notify and leave early
        if not self.config.data.enabled.predict:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Prediction set")
            return

        index = self.get_dataset_index("predict")
        dataset_config = self.config.data.datasets[index]

        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % dataset_config.name)

        predict_transforms = self.config.transform.predict

        # Input Transform Manager (every transform entry except "outputs")
        input_kwargs = predict_transforms.get(ignore="outputs")
        transform_manager = TransformManager(**input_kwargs)

        # Output Transform Manager (only the "outputs" entry)
        output_files = predict_transforms.get("outputs")
        output_transform_manager = OutputTransformer(
            transform_files=output_files)

        # Dataset, built from its configuration without the "type" entry
        dataset_kwargs = dataset_config.get(ignore="type")
        dataset = Dataset(**dataset_kwargs,
                          transform_manager=transform_manager)

        # Predictor
        dataloader_kwargs = self.config.data.dataloader.get()
        self.predictor = Predictor(
            **dataloader_kwargs,
            model=self.model,
            dataset=dataset,
            transform_manager=output_transform_manager)
예제 #4
0
    def load_trainer(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load a trainer if the train step is enabled

        Builds the input transform manager, the output transformer and the
        training dataset, loads the scheduler when an optimizer exists and the
        scheduler is enabled, then stores a Trainer in self.trainer.
        If the train step is disabled, only an info notification is issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Train step disabled: notify and leave early
        if not self.config.data.enabled.train:
            Notification(DEEP_NOTIF_INFO, "Trainer disabled")
            return

        self.loading_message("Trainer")
        train_transforms = self.config.transform.train

        # Input transform manager (every transform entry except "outputs")
        transform_manager = TransformManager(
            **train_transforms.get(ignore="outputs"))

        # Output transformer (only the "outputs" entry)
        output_transform_manager = OutputTransformer(
            transform_files=train_transforms.get("outputs"))

        # Training dataset; batch_size is left out of the Dataset arguments
        index = self.get_dataset_index(DEEP_DATASET_TRAIN)
        dataset_config = self.config.data.datasets[index]
        dataset = Dataset(
            **dataset_config.get(ignore=["batch_size"]),
            transform_manager=transform_manager)

        # Scheduler: only loaded when there is an optimizer to schedule
        scheduler_requested = (self.optimizer is not None
                               and self.config.training.scheduler.enabled)
        if scheduler_requested:
            self.load_scheduler()

        # Trainer; the two unpackings stay separate call arguments so that a
        # duplicated keyword is still reported by the call itself
        dataloader_kwargs = self.config.data.dataloader.get()
        training_kwargs = self.config.training.get(
            ignore=["overwatch", "saver", "scheduler"])
        self.trainer = Trainer(
            dataset,
            **dataloader_kwargs,
            batch_size=dataset_config.batch_size,
            model=self.model,
            metrics=self.metrics,
            losses=self.losses,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            **training_kwargs,
            validator=self.validator,
            transform_manager=output_transform_manager)
예제 #5
0
    def load_validator(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the validation inferer in memory if the validation step is enabled

        Builds the input transform manager, the output transformer and the
        validation dataset, stores a Tester named "Validator" in
        self.validator and, when a trainer already exists, points
        trainer.tester at it.
        If the validation step is disabled, only an info notification is
        issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Validation step disabled: notify and leave early
        if not self.config.data.enabled.validation:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % DEEP_DATASET_VAL.name)
            return

        self.loading_message("Validator")
        validation_transforms = self.config.transform.validation

        # Input transform manager (every transform entry except "outputs")
        transform_manager = TransformManager(
            **validation_transforms.get(ignore="outputs"))

        # Output transformer (only the "outputs" entry)
        output_transform_manager = OutputTransformer(
            transform_files=validation_transforms.get("outputs"))

        # Validation dataset; batch_size is left out of the Dataset arguments
        index = self.get_dataset_index(DEEP_DATASET_VAL)
        dataset_config = self.config.data.datasets[index]
        dataset = Dataset(
            **dataset_config.get(ignore=["batch_size"]),
            transform_manager=transform_manager)

        # Validator
        dataloader_kwargs = self.config.data.dataloader.get()
        self.validator = Tester(
            **dataloader_kwargs,
            batch_size=dataset_config.batch_size,
            model=self.model,
            dataset=dataset,
            metrics=self.metrics,
            losses=self.losses,
            transform_manager=output_transform_manager,
            name="Validator")

        # Keep an already-loaded trainer in sync with the new validator
        if self.trainer is not None:
            self.trainer.tester = self.validator
예제 #6
0
    def load_validator(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the validation inferer in memory if the validation step is enabled

        Builds the input transform manager, the output transformer and the
        validation dataset, then stores a Tester in self.validator.
        If the validation step is disabled, only an info notification is
        issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Validation step disabled: notify and leave early
        if not self.config.data.enabled.validation:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Validation set")
            return

        index = self.get_dataset_index("validation")
        dataset_config = self.config.data.datasets[index]

        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % dataset_config.name)

        validation_transforms = self.config.transform.validation

        # Input transform manager (every transform entry except "outputs")
        input_kwargs = validation_transforms.get(ignore="outputs")
        transform_manager = TransformManager(**input_kwargs)

        # Output transformer (only the "outputs" entry)
        output_files = validation_transforms.get("outputs")
        output_transform_manager = OutputTransformer(
            transform_files=output_files)

        # Dataset, built from its configuration without the "type" entry
        dataset_kwargs = dataset_config.get(ignore="type")
        dataset = Dataset(**dataset_kwargs,
                          transform_manager=transform_manager)

        # Validator
        dataloader_kwargs = self.config.data.dataloader.get()
        self.validator = Tester(
            **dataloader_kwargs,
            model=self.model,
            dataset=dataset,
            metrics=self.metrics,
            losses=self.losses,
            transform_manager=output_transform_manager)
예제 #7
0
    def load_tester(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the test inferer in memory if the test step is enabled

        Builds the input transform manager, the output transformer and the
        test dataset, then stores a Tester in self.tester.
        If the test step is disabled, only an info notification is issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Test step disabled: notify and leave early
        if not self.config.data.enabled.test:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % DEEP_DATASET_TEST.name)
            return

        self.loading_message("Tester")

        # Locate the test entry among the configured datasets
        index = self.get_dataset_index(DEEP_DATASET_TEST)
        dataset_config = self.config.data.datasets[index]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % dataset_config.name)

        test_transforms = self.config.transform.test

        # Input transform manager (every transform entry except "outputs")
        input_kwargs = test_transforms.get(ignore="outputs")
        transform_manager = TransformManager(**input_kwargs)

        # Output transformer (only the "outputs" entry)
        output_files = test_transforms.get("outputs")
        output_transform_manager = OutputTransformer(
            transform_files=output_files)

        # Test dataset; batch_size is left out of the Dataset arguments
        dataset_kwargs = dataset_config.get(ignore=["batch_size"])
        dataset = Dataset(**dataset_kwargs,
                          transform_manager=transform_manager)

        # Tester
        dataloader_kwargs = self.config.data.dataloader.get()
        self.tester = Tester(
            **dataloader_kwargs,
            batch_size=dataset_config.batch_size,
            model=self.model,
            dataset=dataset,
            metrics=self.metrics,
            losses=self.losses,
            transform_manager=output_transform_manager)
예제 #8
0
    def load_trainer(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load a trainer if the train step is enabled

        Builds the transform manager and the training dataset, then stores a
        Trainer in self.trainer, using self.validator as its tester.
        If the train step is disabled, only an info notification is issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Train step disabled: notify and leave early
        if not self.config.data.enabled.train:
            Notification(
                DEEP_NOTIF_INFO,
                DEEP_MSG_DATA_DISABLED % self.config.data.dataset.train.name)
            return

        train_config = self.config.data.dataset.train
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % train_config.name)

        # Transform manager
        transform_kwargs = self.config.transform.train.get()
        transform_manager = TransformManager(**transform_kwargs)

        # Dataset
        dataset = Dataset(
            **train_config.get(),
            transform_manager=transform_manager,
            cv_library=self.config.project.cv_library)

        # Trainer
        dataloader_kwargs = self.config.data.dataloader.get()
        training = self.config.training
        self.trainer = Trainer(
            **dataloader_kwargs,
            model=self.model,
            dataset=dataset,
            metrics=self.metrics,
            losses=self.losses,
            optimizer=self.optimizer,
            num_epochs=training.num_epochs,
            initial_epoch=training.initial_epoch,
            shuffle_method=training.shuffle,
            verbose=self.config.history.verbose,
            tester=self.validator)
예제 #9
0
    def load_validator(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the validation inferer in memory if the validation step is enabled

        Builds the transform manager and the validation dataset, then stores
        a Tester in self.validator.
        If the validation step is disabled, only an info notification is
        issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Validation step disabled: notify and leave early
        if not self.config.data.enabled.validation:
            Notification(
                DEEP_NOTIF_INFO,
                DEEP_MSG_DATA_DISABLED %
                self.config.data.dataset.validation.name)
            return

        validation_config = self.config.data.dataset.validation
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % validation_config.name)

        # Transform manager
        transform_kwargs = self.config.transform.validation.get()
        transform_manager = TransformManager(**transform_kwargs)

        # Dataset
        dataset = Dataset(
            **validation_config.get(),
            transform_manager=transform_manager,
            cv_library=self.config.project.cv_library)

        # Validator
        dataloader_kwargs = self.config.data.dataloader.get()
        self.validator = Tester(
            **dataloader_kwargs,
            model=self.model,
            dataset=dataset,
            metrics=self.metrics,
            losses=self.losses)
예제 #10
0
    def load_predictor(self):
        """
        DESCRIPTION:
        ------------

        Load the predictor if the predict step is enabled

        Builds the transform manager and the prediction dataset, then stores
        a Predictor in self.predictor.
        If the predict step is disabled, only an info notification is issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Predict step disabled: notify and leave early
        if not self.config.data.enabled.predict:
            Notification(
                DEEP_NOTIF_INFO,
                DEEP_MSG_DATA_DISABLED % self.config.data.dataset.predict.name)
            return

        predict_config = self.config.data.dataset.predict
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % predict_config.name)

        # Transform manager (every transform entry except "outputs")
        transform_kwargs = self.config.transform.predict.get(ignore="outputs")
        transform_manager = TransformManager(**transform_kwargs)

        # Dataset
        dataset = Dataset(
            **predict_config.get(),
            transform_manager=transform_manager,
            cv_library=self.config.project.cv_library)

        # Predictor
        dataloader_kwargs = self.config.data.dataloader.get()
        self.predictor = Predictor(
            **dataloader_kwargs,
            model=self.model,
            dataset=dataset)
예제 #11
0
    def load_trainer(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load a trainer if the train step is enabled

        Builds the input transform manager, the output transformer and the
        training dataset, then stores a Trainer in self.trainer, using
        self.validator as its tester.
        If the train step is disabled, only an info notification is issued.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Train step disabled: notify and leave early
        if not self.config.data.enabled.train:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Training set")
            return

        index = self.get_dataset_index("train")
        dataset_config = self.config.data.datasets[index]

        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % dataset_config.name)

        train_transforms = self.config.transform.train

        # Input transform manager (every transform entry except "outputs")
        input_kwargs = train_transforms.get(ignore="outputs")
        transform_manager = TransformManager(**input_kwargs)

        # Output transformer (only the "outputs" entry)
        output_files = train_transforms.get("outputs")
        output_transform_manager = OutputTransformer(
            transform_files=output_files)

        # Dataset, built from its configuration without the "type" entry
        dataset_kwargs = dataset_config.get(ignore="type")
        dataset = Dataset(**dataset_kwargs,
                          transform_manager=transform_manager)

        # Trainer
        dataloader_kwargs = self.config.data.dataloader.get()
        training = self.config.training
        self.trainer = Trainer(
            **dataloader_kwargs,
            model=self.model,
            dataset=dataset,
            metrics=self.metrics,
            losses=self.losses,
            optimizer=self.optimizer,
            num_epochs=training.num_epochs,
            initial_epoch=training.initial_epoch,
            shuffle_method=training.shuffle,
            verbose=self.config.history.verbose,
            tester=self.validator,
            transform_manager=output_transform_manager)