Exemple #1
0
    def load_trainer(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Build the Trainer for the training step and store it in
        ``self.trainer``.

        When ``config.data.enabled.train`` is falsy nothing is built and an
        info notification ("Trainer disabled") is emitted instead.
        When an optimizer exists and ``config.training.scheduler.enabled`` is
        truthy, ``self.load_scheduler()`` runs first so that
        ``self.scheduler`` can be handed to the Trainer.

        PARAMETERS:
        -----------
        None

        RETURN:
        -------

        :return None
        """
        # Guard clause: nothing to do when the train step is switched off
        if not self.config.data.enabled.train:
            Notification(DEEP_NOTIF_INFO, "Trainer disabled")
            return

        self.loading_message("Trainer")
        train_transforms = self.config.transform.train

        # Transforms applied to the inputs (everything except "outputs")
        input_transformer = TransformManager(
            **train_transforms.get(ignore="outputs"))

        # Transforms applied to the model outputs
        output_transformer = OutputTransformer(
            transform_files=train_transforms.get("outputs"))

        # batch_size is kept out of the Dataset kwargs: it goes to the Trainer
        train_index = self.get_dataset_index(DEEP_DATASET_TRAIN)
        train_entry = self.config.data.datasets[train_index]
        train_dataset = Dataset(
            **train_entry.get(ignore=["batch_size"]),
            transform_manager=input_transformer)

        # Only look at the scheduler config once an optimizer is known to exist
        if self.optimizer is not None:
            if self.config.training.scheduler.enabled:
                self.load_scheduler()

        self.trainer = Trainer(
            train_dataset,
            **self.config.data.dataloader.get(),
            batch_size=train_entry.batch_size,
            model=self.model,
            metrics=self.metrics,
            losses=self.losses,
            optimizer=self.optimizer,
            scheduler=self.scheduler,
            **self.config.training.get(
                ignore=["overwatch", "saver", "scheduler"]),
            validator=self.validator,
            transform_manager=output_transformer)
Exemple #2
0
    def load_predictor(self):
        """
        DESCRIPTION:
        ------------

        Prepare the prediction step. This resolves the prediction dataset,
        announces it, and builds its input/output transformers and Dataset.

        NOTE(review): the Predictor construction below is commented out, so
        ``self.predictor`` is not assigned here. The transformers and Dataset
        built by this method are discarded when it returns.

        When ``config.data.enabled.predict`` is falsy, only an info
        notification is emitted.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Guard clause: prediction step disabled
        if not self.config.data.enabled.predict:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % DEEP_DATASET_PREDICTION.name)
            return

        self.loading_message("Predictor")
        predict_index = self.get_dataset_index(DEEP_DATASET_PREDICTION)
        predict_entry = self.config.data.datasets[predict_index]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % predict_entry.name)

        predict_transforms = self.config.transform.predict

        # Transforms applied to the inputs (everything except "outputs")
        input_transformer = TransformManager(
            **predict_transforms.get(ignore="outputs"))

        # Transforms applied to the model outputs
        output_transformer = OutputTransformer(
            transform_files=predict_transforms.get("outputs"))

        # batch_size is excluded from the Dataset kwargs
        predict_dataset = Dataset(
            **predict_entry.get(ignore=["batch_size"]),
            transform_manager=input_transformer)

        # Predictor construction (currently disabled):
        # self.predictor = Predictor(
        #     **self.config.data.dataloader.get(),
        #     batch_size=predict_entry.batch_size,
        #     name="Predictor",
        #     model=self.model,
        #     dataset=predict_dataset,
        #     transform_manager=output_transformer
        # )
Exemple #3
0
    def load_predictor(self):
        """
        DESCRIPTION:
        ------------

        Build the Predictor for the "predict" dataset and store it in
        ``self.predictor``.

        When ``config.data.enabled.predict`` is falsy, only an info
        notification is emitted.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Guard clause: prediction step disabled
        if not self.config.data.enabled.predict:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Prediction set")
            return

        predict_entry = self.config.data.datasets[
            self.get_dataset_index("predict")]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % predict_entry.name)

        predict_transforms = self.config.transform.predict

        # Transforms applied to the inputs (everything except "outputs")
        input_transformer = TransformManager(
            **predict_transforms.get(ignore="outputs"))

        # Transforms applied to the model outputs
        output_transformer = OutputTransformer(
            transform_files=predict_transforms.get("outputs"))

        # The "type" field of the dataset entry is not a Dataset kwarg
        predict_dataset = Dataset(
            **predict_entry.get(ignore="type"),
            transform_manager=input_transformer)

        self.predictor = Predictor(
            **self.config.data.dataloader.get(),
            model=self.model,
            dataset=predict_dataset,
            transform_manager=output_transformer)
Exemple #4
0
    def load_validator(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Build the validation Tester (named "Validator") and store it in
        ``self.validator``.

        If a trainer has already been loaded, its ``tester`` attribute is
        pointed at the new validator. When ``config.data.enabled.validation``
        is falsy, only an info notification is emitted.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Guard clause: validation step disabled
        if not self.config.data.enabled.validation:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % DEEP_DATASET_VAL.name)
            return

        self.loading_message("Validator")
        val_transforms = self.config.transform.validation

        # Transforms applied to the inputs (everything except "outputs")
        input_transformer = TransformManager(
            **val_transforms.get(ignore="outputs"))

        # Transforms applied to the model outputs
        output_transformer = OutputTransformer(
            transform_files=val_transforms.get("outputs"))

        # batch_size is handed to the Tester rather than to the Dataset
        val_index = self.get_dataset_index(DEEP_DATASET_VAL)
        val_entry = self.config.data.datasets[val_index]
        val_dataset = Dataset(
            **val_entry.get(ignore=["batch_size"]),
            transform_manager=input_transformer)

        self.validator = Tester(
            **self.config.data.dataloader.get(),
            batch_size=val_entry.batch_size,
            model=self.model,
            dataset=val_dataset,
            metrics=self.metrics,
            losses=self.losses,
            transform_manager=output_transformer,
            name="Validator")

        # Keep an already-loaded trainer in sync with the new validator
        trainer = self.trainer
        if trainer is not None:
            trainer.tester = self.validator
Exemple #5
0
    def load_validator(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Build the validation Tester for the "validation" dataset and store
        it in ``self.validator``.

        When ``config.data.enabled.validation`` is falsy, only an info
        notification is emitted.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Guard clause: validation step disabled
        if not self.config.data.enabled.validation:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Validation set")
            return

        val_entry = self.config.data.datasets[
            self.get_dataset_index("validation")]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % val_entry.name)

        val_transforms = self.config.transform.validation

        # Transforms applied to the inputs (everything except "outputs")
        input_transformer = TransformManager(
            **val_transforms.get(ignore="outputs"))

        # Transforms applied to the model outputs
        output_transformer = OutputTransformer(
            transform_files=val_transforms.get("outputs"))

        # The "type" field of the dataset entry is not a Dataset kwarg
        val_dataset = Dataset(
            **val_entry.get(ignore="type"),
            transform_manager=input_transformer)

        self.validator = Tester(
            **self.config.data.dataloader.get(),
            model=self.model,
            dataset=val_dataset,
            metrics=self.metrics,
            losses=self.losses,
            transform_manager=output_transformer)
Exemple #6
0
    def load_tester(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Build the Tester for the test dataset and store it in
        ``self.tester``.

        When ``config.data.enabled.test`` is falsy, only an info
        notification is emitted.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Guard clause: test step disabled
        if not self.config.data.enabled.test:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % DEEP_DATASET_TEST.name)
            return

        self.loading_message("Tester")
        test_index = self.get_dataset_index(DEEP_DATASET_TEST)
        test_entry = self.config.data.datasets[test_index]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % test_entry.name)

        test_transforms = self.config.transform.test

        # Transforms applied to the inputs (everything except "outputs")
        input_transformer = TransformManager(
            **test_transforms.get(ignore="outputs"))

        # Transforms applied to the model outputs
        output_transformer = OutputTransformer(
            transform_files=test_transforms.get("outputs"))

        # batch_size is handed to the Tester rather than to the Dataset
        test_dataset = Dataset(
            **test_entry.get(ignore=["batch_size"]),
            transform_manager=input_transformer)

        self.tester = Tester(
            **self.config.data.dataloader.get(),
            batch_size=test_entry.batch_size,
            model=self.model,
            dataset=test_dataset,
            metrics=self.metrics,
            losses=self.losses,
            transform_manager=output_transformer)
Exemple #7
0
    def load_trainer(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Build the Trainer for the "train" dataset and store it in
        ``self.trainer``. The current ``self.validator`` is passed as the
        trainer's tester.

        When ``config.data.enabled.train`` is falsy, only an info
        notification is emitted.

        PARAMETERS:
        -----------
        None

        RETURN:
        -------

        :return None
        """
        # Guard clause: train step disabled
        if not self.config.data.enabled.train:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Training set")
            return

        train_entry = self.config.data.datasets[
            self.get_dataset_index("train")]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_NOTIF_DATA_LOADING % train_entry.name)

        train_transforms = self.config.transform.train

        # Transforms applied to the inputs (everything except "outputs")
        input_transformer = TransformManager(
            **train_transforms.get(ignore="outputs"))

        # Transforms applied to the model outputs
        output_transformer = OutputTransformer(
            transform_files=train_transforms.get("outputs"))

        # The "type" field of the dataset entry is not a Dataset kwarg
        train_dataset = Dataset(
            **train_entry.get(ignore="type"),
            transform_manager=input_transformer)

        training_cfg = self.config.training
        self.trainer = Trainer(
            **self.config.data.dataloader.get(),
            model=self.model,
            dataset=train_dataset,
            metrics=self.metrics,
            losses=self.losses,
            optimizer=self.optimizer,
            num_epochs=training_cfg.num_epochs,
            initial_epoch=training_cfg.initial_epoch,
            shuffle_method=training_cfg.shuffle,
            verbose=self.config.history.verbose,
            tester=self.validator,
            transform_manager=output_transformer)