Example #1
    def load_predictor(self):
        # If the predict step is enabled
        if self.config.data.enabled.predict:
            self.loading_message("Predictor")
            i = self.get_dataset_index(DEEP_DATASET_PREDICTION)
            Notification(
                DEEP_NOTIF_INFO,
                DEEP_NOTIF_DATA_LOADING % self.config.data.datasets[i].name)

            # Input Transform Manager
            transform_manager = TransformManager(
                **self.config.transform.predict.get(ignore="outputs"))

            # Output Transform Manager
            output_transform_manager = OutputTransformer(
                transform_files=self.config.transform.predict.get("outputs"))

            # Initialise prediction dataset
            dataset = Dataset(**self.config.data.datasets[i].get(
                ignore=["batch_size"]),
                              transform_manager=transform_manager)

            # Initialise predictor
            #self.predictor = Predictor(
            #    **self.config.data.dataloader.get(),
            #    batch_size = self.config.data.datasets[i].batch_size,
            #    name="Predictor",
            #    model=self.model,
            #    dataset=dataset,
            #    transform_manager=output_transform_manager
            #)
        else:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % DEEP_DATASET_PREDICTION.name)
Example #2
    def fit(self, first_training: bool = True) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Fit the model to the dataset

        PARAMETERS:
        -----------

        :param first_training: (bool, optional): Whether it is the first training on the model or not

        RETURN:
        -------

        :return: None
        """
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_TRAINING_STARTED)
        self.__train(first_training=first_training)
        Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_TRAINING_FINISHED)
Example #3
    def summary(self) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Print the summary of the Pointer

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        Notification(DEEP_NOTIF_INFO, "------------------------------------")
        Notification(DEEP_NOTIF_INFO,
                     "Transformer '" + str(self.name) + "' summary :")
        Notification(
            DEEP_NOTIF_INFO, "Points to the %ith %s." %
            (self.transformer_index, self.transformer_entry.get_name()))
        Notification(DEEP_NOTIF_INFO, "------------------------------------")
Example #4
    def __import_cv_library(self):
        """
        AUTHORS:
        --------
        :author: Samuel Westlake

        DESCRIPTION:
        ------------
        Imports either cv2 or PIL.Image, depending on the value of self.cv_library

        PARAMETERS:
        -----------
        None

        RETURN:
        -------
        None
        """
        if self.cv_library == DEEP_LIB_OPENCV:
            try:
                Notification(DEEP_NOTIF_INFO, DEEP_MSG_CV_LIBRARY_SET % "OPENCV")
                global cv2
                import cv2
            except ImportError as e:
                Notification(DEEP_NOTIF_ERROR, str(e))
        elif self.cv_library == DEEP_LIB_PIL:
            try:
                Notification(DEEP_NOTIF_INFO, DEEP_MSG_CV_LIBRARY_SET % "PILLOW")
                global Image
                from PIL import Image
            except ImportError as e:
                Notification(DEEP_NOTIF_ERROR, str(e))
        else:
            Notification(DEEP_NOTIF_ERROR, DEEP_MSG_CV_LIBRARY_NOT_IMPLEMENTED % self.cv_library)
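For reference, a minimal standalone sketch (not Deeplodocus code) of the lazy global-import pattern used in __import_cv_library above, with json standing in for cv2/PIL so it runs anywhere:

    def load_json_backend():
        # 'global' makes the import bind the module name at module scope,
        # so other functions in the same file can use it without re-importing
        global json
        import json

    load_json_backend()
    print(json.dumps({"ok": True}))  # 'json' is now visible outside the function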
Example #5
 def load_transform_functions(self):
     """
     Calls the get_module function for each transform in each transformer
     Loads the returned method and edits the module entry to reflect the origin
     :return: None
     """
     # For each transform sequence
     for i, sequence in enumerate(self.output_transformer):
         # Check that the transforms entry exists
         self.__check_transforms_exists(sequence, i)
         if sequence.transforms is not None:
             for transform_name, transform_info in sequence.transforms.get(
             ).items():
                 self.__transform_loading(transform_name, transform_info)
                 method, module_path = get_module(
                     **transform_info.get(ignore="kwargs"),
                     browse=DEEP_MODULE_TRANSFORMS)
                 if method is None:
                     self.__method_not_found(transform_info)
                 if isinstance(method, types.FunctionType):
                     transform_info.add({"method": method})
                 else:
                     try:
                         transform_info.add({
                             "method":
                             method(**transform_info.kwargs.get())
                         })
                     except TypeError as e:
                         Notification(DEEP_NOTIF_FATAL, str(e))
                 transform_info.module = module_path
                 Notification(
                     DEEP_NOTIF_SUCCESS,
                     "Loaded transform : %s : %s from %s" %
                     (transform_name, transform_info.name,
                      transform_info.module))
Example #6
    def __thanks_master():
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Display thanks message

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: Universal Love <3
        """
        Notification(DEEP_NOTIF_INFO, "=================================")
        Notification(DEEP_NOTIF_INFO, "Thank you for using Deeplodocus !")
        Notification(DEEP_NOTIF_INFO, "== Made by Humans with deep <3 ==")
        Notification(DEEP_NOTIF_INFO, "=================================")
Example #7
    def train(self, num_epochs: Union[int, None] = None):
        # Pre-training checks
        if self.model is None:
            Notification(
                DEEP_NOTIF_ERROR,
                "Could not begin training : No model detected by the trainer")
        if self.losses is None:
            Notification(
                DEEP_NOTIF_ERROR,
                "Could not begin training : No losses detected by the trainer")
        if self.optimizer is None:
            Notification(
                DEEP_NOTIF_ERROR,
                "Could not begin training : No optimizer detected by the trainer"
            )

        # Update num_epochs
        self.num_epochs = self.num_epochs if num_epochs is None else num_epochs

        # Infer initial epoch
        if self.initial_epoch is None:
            self.initial_epoch = self.model.epoch if "epoch" in vars(
                self.model).keys() else 0

        # Go
        self.training_start()
        for self.epoch in range(self.initial_epoch + 1,
                                self.num_epochs + self.initial_epoch + 1):
            self.epoch_start()
            for self.batch_index, batch in enumerate(self.dataloader, 1):
                if self.accumulate == 1:
                    self.forward(batch)
                else:
                    self.forward2(batch)
            self.epoch_end()
        self.training_end()
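The epoch bookkeeping above resumes counting after the last completed epoch; a quick standalone illustration with assumed numbers:

    # Assumed values: 3 epochs already trained (initial_epoch), 5 more requested (num_epochs).
    initial_epoch, num_epochs = 3, 5
    epochs = list(range(initial_epoch + 1, num_epochs + initial_epoch + 1))
    assert epochs == [4, 5, 6, 7, 8]  # training continues where the checkpoint left off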
Example #8
    def summary():
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Print the summary of the NoTransformer

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        Notification(DEEP_NOTIF_INFO, "------------------------------------")
        Notification(DEEP_NOTIF_INFO, "No Transformer for this entry")
        Notification(DEEP_NOTIF_INFO, "------------------------------------")
Example #9
    def __set_number(self) -> None:
        """
        AUTHORS:
        --------
        :author: Samuel Westlake
        :author: Alix Leroy

        DESCRIPTION:
        ------------
        Set the length of the dataset

        RETURN:
        -------
        :return: None
        """
        if self.number is None:
            self.number = len(self.data)
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_NO_LENGTH % len(self.data))
        else:
            if self.number > len(self.data):
                self.number = len(self.data)
                Notification(
                    DEEP_NOTIF_WARNING,
                    DEEP_MSG_DATA_TOO_LONG % (self.number, len(self.data)))
            else:
                Notification(
                    DEEP_NOTIF_INFO,
                    DEEP_MSG_DATA_LENGTH % (len(self.data), self.number))
Example #10
    def shuffle(self, method: int) -> None:
        """
        AUTHORS:
        --------
        :author: Alix Leroy

        DESCRIPTION:
        ------------
        Shuffle the dataframe containing the data

        PARAMETERS:
        -----------
        :param method: (int): The shuffling method flag (e.g. DEEP_SHUFFLE_ALL)

        RETURN:
        -------
        :return: None
        """

        if method == DEEP_SHUFFLE_ALL:
            try:
                self.data = self.data.sample(frac=1).reset_index(drop=True)
            # TODO: Catch a more specific exception here
            except Exception:
                Notification(DEEP_NOTIF_ERROR, "Cannot shuffle the dataset")
        else:
            Notification(DEEP_NOTIF_ERROR,
                         "The shuffling method does not exist.")

        # Reset the TransformManager
        self.reset()
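The shuffle above relies on the standard pandas idiom sample(frac=1) followed by an index reset; a minimal standalone sketch:

    import pandas as pd

    df = pd.DataFrame({"x": [1, 2, 3, 4]})
    shuffled = df.sample(frac=1).reset_index(drop=True)  # every row, random order, fresh 0..n-1 index
    assert sorted(shuffled["x"].tolist()) == [1, 2, 3, 4]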
Example #11
    def generate_sources(self, sources: List[dict]) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Generate the sources
        Does not generate the SourcePointer instances

        PARAMETERS:
        -----------

        :param sources(List[dict]): The configuration of the Source instances to generate

        RETURN:
        -------

        :return: None
        """
        list_sources = []
        # Create sources
        for i, source in enumerate(sources):

            # Get the Source module and add it to the list
            m = "default modules" if source["module"] is None else source["module"]
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_LOADING % ("source [%i]" % i, source["name"], m))
            method, module_path = get_module(name=source["name"], module=source["module"], browse=DEEP_MODULE_SOURCES)
            if method is None:
                Notification(DEEP_NOTIF_FATAL, DEEP_MSG_MODULE_NOT_FOUND % (source["name"], module_path))
            else:
                Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_LOADED % ("source [%i]" % i, source["name"], module_path))

            # If the source is not a real source
            if not issubclass(method, Source):
                # Remove the id from the initial kwargs
                index = source["kwargs"].pop('index', None)

                # Create a source wrapper with the new ID
                s = SourceWrapper(
                    index=index,
                    name=source["name"],
                    module=source["module"],
                    kwargs=source["kwargs"]
                )
            else:
                # If the subclass is a real Source
                s = method(**source["kwargs"])

            # Check the module inherits the generic Source class
            self.check_type_sources(s, i)

            # Add the module to the list of Source instances
            list_sources.append(s)

        # Set the list as the attribute
        self.sources = list_sources
Example #12
    def __check_directory(self) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Check the given directory exists

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        if not os.path.isdir(self.path):
            Notification(
                DEEP_NOTIF_FATAL,
                "The following path is not a Source folder : %s " % self.path)
        else:
            Notification(DEEP_NOTIF_SUCCESS,
                         "Source folder \"%s\" successfully found" % self.path)
Example #13
    def load_optimizer(self):
        """
        AUTHORS:
        --------

        :author: Samuel Westlake
        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the optimizer with the adequate parameters

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        if self.model is not None:
            self.optimizer = Optimizer(self.model.parameters(),
                                       self.config.optimizer).get()
            Notification(
                DEEP_NOTIF_SUCCESS, DEEP_MSG_OPTIMIZER_LOADED %
                (self.config.optimizer.name, self.optimizer.__module__))
        else:
            Notification(DEEP_NOTIF_ERROR,
                         DEEP_MSG_OPTIMIZER_NOT_LOADED % DEEP_MSG_MODEL_LOADED)
Example #14
    def load_predictor(self):
        # If the predict step is enabled
        if self.config.data.enabled.predict:

            predict_index = self.get_dataset_index("predict")

            Notification(
                DEEP_NOTIF_INFO, DEEP_NOTIF_DATA_LOADING %
                self.config.data.datasets[predict_index].name)

            # Input Transform Manager
            transform_manager = TransformManager(
                **self.config.transform.predict.get(ignore="outputs"))

            # Output Transform Manager
            output_transform_manager = OutputTransformer(
                transform_files=self.config.transform.predict.get("outputs"))

            # Dataset
            dataset = Dataset(**self.config.data.datasets[predict_index].get(
                ignore="type"),
                              transform_manager=transform_manager)

            # Predictor
            self.predictor = Predictor(
                **self.config.data.dataloader.get(),
                model=self.model,
                dataset=dataset,
                transform_manager=output_transform_manager)
        else:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Prediction set")
Example #15
 def __import(self, *args):
     for arg in args:
         items = []
         for m in DEEP_LIST_MODULE:
             method, module_path = get_module(arg, browse=m)
             if method is not None:
                 items.append({
                     "method": method,
                     "module_path": module_path,
                     "file": inspect.getsourcefile(method),
                     "prefix": m["custom"]["prefix"],
                     "name": arg
                 })
         if not items:
             Notification(DEEP_NOTIF_ERROR, "module not found : %s" % arg)
         elif len(items) == 1:
             self.__import_module(items[0])
         else:
             Notification(DEEP_NOTIF_INFO,
                          "Multiple modules found with the same name :")
             for i, item in enumerate(items):
                 Notification(
                     DEEP_NOTIF_INFO,
                     "  %i. %s from %s" % (i, arg, item["module_path"]))
             while True:
                 i = Notification(DEEP_NOTIF_INPUT,
                                  "Which would you like?").get()
                 try:
                     i = int(i)
                     break
                 except ValueError:
                     pass
             self.__import_module(items[i])
Example #16
 def __check_args(self) -> dict:
     if inspect.isfunction(self.method):
         args = inspect.getfullargspec(self.method).args
     else:
         args = inspect.getfullargspec(self.method.forward).args
         args.remove("self")
     args_dict = {}
     for arg in args:
         if self.is_custom:
             entry_flag = get_corresponding_flag(DEEP_LIST_ENTRY,
                                                 arg,
                                                 fatal=False)
             if entry_flag is None:
                 Notification(DEEP_NOTIF_FATAL,
                              DEEP_MSG_LOSS_UNEXPECTED_ARG % arg)
             else:
                 args_dict[entry_flag] = arg
         else:
             if arg in ("input", "x", "inputs"):
                 args_dict[DEEP_ENTRY_OUTPUT] = arg
             elif arg in ("y", "y_hat", "label", "labels", "target",
                          "targets"):
                 args_dict[DEEP_ENTRY_LABEL] = arg
             else:
                 Notification(DEEP_NOTIF_FATAL,
                              DEEP_MSG_LOSS_UNEXPECTED_ARG % arg)
     self.check_essential_args(args_dict)
     return {
         arg_name: entry_flag
         for entry_flag, arg_name in args_dict.items()
     }
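A small standalone sketch of the inversion returned at the end of __check_args, with plain strings standing in for the DEEP_ENTRY_* flag objects:

    args_dict = {"DEEP_ENTRY_OUTPUT": "x", "DEEP_ENTRY_LABEL": "y"}  # flag -> argument name
    inverted = {arg_name: entry_flag for entry_flag, arg_name in args_dict.items()}
    assert inverted == {"x": "DEEP_ENTRY_OUTPUT", "y": "DEEP_ENTRY_LABEL"}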
Example #17
 def summary(self):
     Notification(DEEP_NOTIF_INFO, '================================================================')
     Notification(DEEP_NOTIF_INFO, "OPTIMIZER : %s from %s" % (self.name, self.module))
     Notification(DEEP_NOTIF_INFO, '================================================================')
     for key, value in self.kwargs.items():
         if key != "name":
             Notification(DEEP_NOTIF_INFO, "%s : %s" % (key, value))
     Notification(DEEP_NOTIF_INFO, "")
Example #18
    def summary(self):

        Notification(DEEP_NOTIF_INFO, "------------------------------------")
        Notification(DEEP_NOTIF_INFO,
                     "Transformer '" + str(self.name) + "' summary :")
        Notification(DEEP_NOTIF_INFO,
                     "Points to : " + str(self.pointer_to_transformer))
        Notification(DEEP_NOTIF_INFO, "------------------------------------")
Example #19
    def __check_move_axis(
            self, move_axis: Optional[List[int]]) -> Optional[List[int]]:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Check that the move_axis argument is a list of integers or None

        PARAMETERS:
        -----------

        :param move_axis (Optional[List[int]]): The new order of axes

        RETURN:
        -------
        :return (Optional[List[int]]): The checked move_axis value
        """
        # Check if None
        if move_axis is None:
            return None

        if isinstance(move_axis, list):
            for i in move_axis:
                if isinstance(i, int) is False:

                    # Get index of the PipelineEntry instance (through its weakref)
                    entry_info = self.pipeline_entry().get_index()

                    # Get info of the Dataset (through its weakref) from the weakref of the PipelineEntry instance
                    dataset_info = self.pipeline_entry().get_dataset()(
                    ).get_info()

                    Notification(
                        DEEP_NOTIF_FATAL,
                        "Please check the value of the move_axis argument in the Entry %i of the Dataset %s"
                        % (entry_info, dataset_info),
                        solutions="One of the items in the list is not an integer")

            return move_axis
        else:
            # Get index of the PipelineEntry instance (through its weakref)
            entry_info = self.pipeline_entry().get_index()

            # Get info of the Dataset (through its weakref) from the weakref of the PipelineEntry instance
            dataset_info = self.pipeline_entry().get_dataset().get_info()

            Notification(
                DEEP_NOTIF_FATAL,
                "Please check the value of the move_axis argument in the Entry %i of the Dataset %s"
                % (entry_info, dataset_info),
                solutions="The move_axis argument must be a list of integers indexed from 0")
Example #20
    def load_optimizer(self):
        """
        AUTHORS:
        --------

        :author: Samuel Westlake
        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the optimizer with the adequate parameters

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # If a module is specified, edit the optimizer name to include the module (for notification purposes)
        if self.config.optimizer.module is None:
            optimizer_path = "%s from default modules" % self.config.optimizer.name
        else:
            optimizer_path = "%s from %s" % (self.config.optimizer.name,
                                             self.config.optimizer.module)

        # Notify the user which model is being collected and from where
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_OPTIM_LOADING % optimizer_path)

        # An optimizer cannot be loaded without a model (self.model.parameters() is required)
        if self.model is not None:
            # Load the optimizer
            optimizer = load_optimizer(
                model_parameters=self.model.parameters(),
                **self.config.optimizer.get())
            msg = "%s from %s" % (self.config.optimizer.name, optimizer.module)
            if self.config.model.from_file:
                checkpoint = self.__load_checkpoint()
                if "optimizer_state_dict" in checkpoint:
                    optimizer.load_state_dict(
                        checkpoint["optimizer_state_dict"])
                    msg += " with state dict from %s" % self.config.model.file

            # If the optimizer was found, load it into the frontal lobe
            if optimizer is None:
                Notification(DEEP_NOTIF_FATAL,
                             DEEP_MSG_OPTIM_NOT_FOUND % optimizer_path)
            else:
                self.optimizer = optimizer
                Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_OPTIM_LOADED % msg)

        # Notify the user that a model must be loaded
        else:
            Notification(DEEP_NOTIF_FATAL, DEEP_MSG_OPTIM_MODEL_NOT_LOADED)
Example #21
    def load_validator(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the validation inferer in memory

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # If the validation step is enabled
        if self.config.data.enabled.validation:

            validation_index = self.get_dataset_index("validation")

            Notification(
                DEEP_NOTIF_INFO, DEEP_NOTIF_DATA_LOADING %
                self.config.data.datasets[validation_index].name)

            # Transform Manager
            transform_manager = TransformManager(
                **self.config.transform.validation.get(ignore="outputs"))

            # Output Transformer
            output_transform_manager = OutputTransformer(
                transform_files=self.config.transform.validation.get(
                    "outputs"))
            # output_transformer.summary()

            # Dataset
            dataset = Dataset(**self.config.data.datasets[validation_index].
                              get(ignore="type"),
                              transform_manager=transform_manager)

            # Validator
            self.validator = Tester(**self.config.data.dataloader.get(),
                                    model=self.model,
                                    dataset=dataset,
                                    metrics=self.metrics,
                                    losses=self.losses,
                                    transform_manager=output_transform_manager)
        else:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % "Validation set")
Example #22
 def __run_project(self):
     Notification(DEEP_NOTIF_WARNING, "The run project command does not support custom modules")
     Notification(DEEP_NOTIF_WARNING, "If you need to import custom modules use : python3 main.py")
     try:
         config_dir = self.argv[2]
     except IndexError:
         config_dir = "./config"
     brain = Brain(config_dir=config_dir)
     brain.wake()
Example #23
    def load_metrics(self):
        """
        AUTHORS:
        --------
        :author: Alix Leroy

        DESCRIPTION:
        ------------
        Load the metrics

        PARAMETERS:
        -----------
        None

        RETURN:
        -------
        :return: None (the metrics are stored in self.metrics)
        """
        metrics = {}
        if self.config.metrics.get():
            for key, config in self.config.metrics.get().items():

                # Get the expected metric path (for notification purposes)
                if config.module is None:
                    metric_path = "%s : %s from default modules" % (key, config.name)
                else:
                    metric_path = "%s : %s from %s" % (key, config.name, config.module)

                # Notify the user which metric is being collected and from where
                Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_LOADING % metric_path)

                # Get the metric object
                metric, module = get_module(
                    name=config.name,
                    module=config.module,
                    browse={**DEEP_MODULE_METRICS, **DEEP_MODULE_LOSSES},
                    silence=True
                )

                # If metric is not found by get_module
                if metric is None:
                    Notification(DEEP_NOTIF_FATAL, DEEP_MSG_METRIC_NOT_FOUND % config.name)

                # Check if the metric is a class or a stand-alone function
                if inspect.isclass(metric):
                    method = metric(**config.kwargs.get())
                else:
                    method = metric

                # Add to the dictionary of metrics and notify of success
                metrics[str(key)] = Metric(name=str(key), method=method)
                Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_METRIC_LOADED % (key, config.name, module))
        else:
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_NONE)

        self.metrics = Metrics(metrics)
Example #24
    def load_tester(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the test inferer in memory

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # If the test step is enabled
        if self.config.data.enabled.test:
            self.loading_message("Tester")
            i = self.get_dataset_index(DEEP_DATASET_TEST)
            Notification(
                DEEP_NOTIF_INFO,
                DEEP_NOTIF_DATA_LOADING % self.config.data.datasets[i].name)

            # Input Transform Manager
            transform_manager = TransformManager(
                **self.config.transform.test.get(ignore="outputs"))

            # Output Transformer
            output_transform_manager = OutputTransformer(
                transform_files=self.config.transform.test.get("outputs"))

            # Initialise test dataset
            dataset = Dataset(**self.config.data.datasets[i].get(
                ignore=["batch_size"]),
                              transform_manager=transform_manager)

            # Initialise tester
            self.tester = Tester(
                **self.config.data.dataloader.get(),
                batch_size=self.config.data.datasets[i].batch_size,
                model=self.model,
                dataset=dataset,
                metrics=self.metrics,
                losses=self.losses,
                transform_manager=output_transform_manager)
        else:
            Notification(DEEP_NOTIF_INFO,
                         DEEP_MSG_DATA_DISABLED % DEEP_DATASET_TEST.name)
Example #25
    def load_losses(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the losses

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None (the losses are stored in self.losses)
        """
        losses = {}
        for key, value in self.config.losses.get().items():
            loss = get_module(config=value, modules=DEEP_MODULE_LOSSES)
            kwargs = check_kwargs(get_kwargs(self.config.losses.get(key)))
            method = loss(**kwargs)
            # Check the weight
            if self.config.losses.check("weight", key):
                if get_int_or_float(value.weight) not in (DEEP_TYPE_INTEGER,
                                                          DEEP_TYPE_FLOAT):
                    Notification(
                        DEEP_NOTIF_FATAL,
                        "The loss function %s doesn't have a correct weight argument"
                        % key)
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The loss function %s doesn't have any weight argument" %
                    key)

            # Create the loss
            if isinstance(method, torch.nn.Module):
                losses[str(key)] = Loss(name=str(key),
                                        weight=float(value.weight),
                                        loss=method)
                Notification(
                    DEEP_NOTIF_SUCCESS,
                    DEEP_MSG_LOSS_LOADED % (key, value.name, loss.__module__))
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The loss function %s is not a torch.nn.Module instance" %
                    key)
        self.losses = losses
Example #26
    def __check_load_method(self, load_method):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Check if the load method required is valid.
        Check if the load method given is an integer, otherwise convert it to the corresponding flag

        PARAMETERS:
        -----------

        :param load_method:

        RETURN:
        -------

        :return load_method (int): The corresponding DEEP_LOAD_METHOD flag
        """

        if isinstance(load_method, int):
            if load_method in DEEP_LOAD_METHOD_LIST:
                return load_method
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The given flag '%i' does not correspond to any DEEP_LOAD_METHOD"
                    % load_method)
        else:
            if load_method == "default":
                return DEEP_LOAD_METHOD_MEMORY
            elif load_method == "memory":
                return DEEP_LOAD_METHOD_MEMORY
            elif load_method in [
                    "harddrive", "hard drive", "hard-drive", "hard_drive"
            ]:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "Loading data using a hard drive reading is not currently implemented"
                )
                return DEEP_LOAD_METHOD_HARDDRIVE
            elif load_method == "server":
                Notification(
                    DEEP_NOTIF_FATAL,
                    "Loading data using a server is not currently implemented")
                return DEEP_LOAD_METHOD_SERVER
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The following loading method does not exist : %s" %
                    str(load_method))
Example #27
    def __check_sources(
            self, sources: Union[str, List[str]],
            join: Union[str, None, List[Union[str, None]]]) -> List[Source]:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Check the source given is a string or a list of strings
        If it is a single source, we transform it to a list of a single element

        PARAMETERS:
        -----------

        :param sources (Union[str, List[str]]) : The source or list of sources
        :param join (Union[str, None, List[Union[str, None]]]): The list of joins

        RETURN:
        -------

        :return sources(List[Source]): List of sources
        """
        formatted_sources = []

        # Convert single elements to a list of one element
        if isinstance(sources, list) is False:
            sources = [sources]

        # Check joins
        join = self.__check_join(join=join)

        # Check the number of jointures and sources are equal:
        if len(sources) != len(join):
            Notification(
                DEEP_NOTIF_FATAL,
                "The entry %s (index %i) does not have the same amount of sources and jointures"
                % (self.entry_type.get_description(), self.entry_index))

        # Check all the elements in the list to make sure they are strings
        for i, s in enumerate(sources):
            if isinstance(s, str) is True:
                source = Source(source=s, join=join[i])
                formatted_sources.append(source)
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The source parameter '%s' in the %ith %s entry is not valid"
                    % (str(s), self.entry_index,
                       self.entry_type.get_description()))

        return formatted_sources
Example #28
 def summary(self, level=0):
     Notification(DEEP_NOTIF_INFO, "%sname: %s" % ("  " * level, self.name))
     Notification(DEEP_NOTIF_INFO,
                  "%smodule path: %s" % ("  " * level, self.module_path))
     Notification(DEEP_NOTIF_INFO,
                  "%sweight: %s" % ("  " * level, self.f2str(self.weight)))
     if bool(self.args):
         Notification(DEEP_NOTIF_INFO, "%sargs: " % "  " * level)
         for arg_name, entry_flag in self.args.items():
             Notification(
                 DEEP_NOTIF_INFO, "%s%s (%s)" %
                 ("  " * (level + 1), arg_name, entry_flag.name))
     else:
         Notification(
             DEEP_NOTIF_INFO,
             "%sargs: []" % "  " * level,
         )
     if bool(self.kwargs):
         Notification(DEEP_NOTIF_INFO, "%skwargs: " % "  " * level)
         for kwarg_name, kwarg_value in self.kwargs.items():
             Notification(
                 DEEP_NOTIF_INFO,
                 "%s%s: %s" % ("  " * (level + 1), kwarg_name, kwarg_value))
     else:
         Notification(DEEP_NOTIF_INFO, "%skwargs: {}" % ("  " * level))
Example #29
 def load_weights(self, weights):
     if weights is not None:
         Notification(DEEP_NOTIF_INFO, "Initializing network weights from file")
         try:
             self.load_state_dict(torch.load(weights)['model_state_dict'], strict=True)
             Notification(DEEP_NOTIF_SUCCESS, "Successfully loaded weights from %s" % weights)
         except RuntimeError as e:
             Notification(DEEP_NOTIF_WARNING, " :" .join(str(e).split(":")[1:]).strip())
             Notification(DEEP_NOTIF_WARNING, "Ignoring unexpected key(s)")
             self.load_state_dict(torch.load(weights)['model_state_dict'], strict=False)
             Notification(DEEP_NOTIF_SUCCESS, "Successfully loaded weights from %s" % weights)
Example #30
 def evaluation_end(self, silent: bool = False):
     self.transform_manager.finish()  # Call finish on all output transforms
     loss, losses = self.losses.reduce(
         self.dataset.type)  # Get total loss and mean of each loss
     metrics = self.metrics.reduce(
         self.dataset.type)  # Get total metric values
     if not silent:
         Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_EVALUATION_FINISHED)
         Notification(DEEP_NOTIF_RESULT,
                      self.compose_text(loss, losses, metrics))
     return loss, losses, metrics