Ejemplo n.º 1
0
    def load(config, model_parameters):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the optimizer in memory.

        The optimizer is looked up from ``config`` with ``get_module``
        (searching DEEP_MODULE_OPTIMIZERS). It is then called with the model
        parameters as the first positional argument, followed by the keyword
        arguments from ``config.kwargs`` after validation by ``check_kwargs``.

        PARAMETERS:
        -----------

        :param config: The optimizer config; must expose a ``kwargs`` attribute
                       (presumably a Namespace, as for the model config -- confirm)
        :param model_parameters: The parameters handed to the optimizer
                                 (presumably ``model.parameters()`` -- confirm with the caller)

        RETURN:
        -------

        :return: The instantiated optimizer
        """
        optimizer = get_module(config=config, modules=DEEP_MODULE_OPTIMIZERS)
        kwargs = check_kwargs(config.kwargs)
        # The model parameters go first, then the user-configured hyperparameters
        return optimizer(model_parameters, **kwargs)
Ejemplo n.º 2
0
    def load(config: Namespace):
        """
        AUTHORS:
        --------

        :author: Samuel Westlake
        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the model.

        The model is looked up from ``config`` with ``get_module``
        (searching DEEP_MODULE_MODELS). It is then called with the keyword
        arguments from ``config.kwargs`` after validation by ``check_kwargs``.

        PARAMETERS:
        -----------

        :param config(Namespace): The parameters from the model config file;
                                  must expose a ``kwargs`` attribute

        RETURN:
        -------

        :return: The instantiated model
        """
        model = get_module(config=config, modules=DEEP_MODULE_MODELS)
        kwargs = check_kwargs(config.kwargs)
        return model(**kwargs)
Ejemplo n.º 3
0
    def __fill_transform_list(self, config_transforms: Namespace):
        """
        AUTHORS:
        --------

        author: Alix Leroy

        DESCRIPTION:
        ------------

        Build the list of transforms with their corresponding methods and arguments.

        Each entry of ``config_transforms`` is resolved with ``get_module``
        (searching DEEP_MODULE_TRANSFORMS). Its keyword arguments are extracted
        with ``get_kwargs`` and validated with ``check_kwargs``.

        PARAMETERS:
        -----------

        :param config_transforms(Namespace): The transforms config; each item
                                             maps a transform name to that
                                             transform's own config

        RETURN:
        -------

        :return: A list with one ``[name, method, kwargs]`` entry per transform,
                 in the order the config's items are iterated
        """
        return [
            [
                key,
                get_module(config=value, modules=DEEP_MODULE_TRANSFORMS),
                check_kwargs(get_kwargs(value.get())),
            ]
            for key, value in config_transforms.get().items()
        ]
Ejemplo n.º 4
0
    def load_losses(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the losses described in ``self.config.losses``.

        For each entry, the loss is resolved with ``get_module`` (searching
        DEEP_MODULE_LOSSES) and called with the entry's checked keyword
        arguments. Each entry must define a ``weight`` that is an int or a
        float, and the resulting object must be a ``torch.nn.Module``.
        Otherwise a fatal notification is issued. Each valid loss is wrapped
        in a ``Loss`` object.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None. The losses are stored in ``self.losses`` as a dict
                 mapping each loss name (as a string) to its ``Loss`` object.
        """
        losses = {}
        for key, value in self.config.losses.get().items():
            name = str(key)
            loss = get_module(config=value, modules=DEEP_MODULE_LOSSES)
            kwargs = check_kwargs(get_kwargs(self.config.losses.get(key)))
            method = loss(**kwargs)

            # The weight is converted with float() when building the Loss below,
            # so it must be present and be an int or a float.
            # NOTE(review): this assumes a DEEP_NOTIF_FATAL notification halts
            # execution; otherwise float(value.weight) below would fail -- confirm.
            if not self.config.losses.check("weight", key):
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The loss function %s doesn't have any weight argument" %
                    key)
            elif get_int_or_float(value.weight) not in (DEEP_TYPE_INTEGER,
                                                        DEEP_TYPE_FLOAT):
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The loss function %s doesn't have a correct weight argument"
                    % key)

            # Only torch.nn.Module instances are accepted as losses
            if isinstance(method, torch.nn.Module):
                losses[name] = Loss(name=name,
                                    weight=float(value.weight),
                                    loss=method)
                Notification(
                    DEEP_NOTIF_SUCCESS,
                    DEEP_MSG_LOSS_LOADED % (key, value.name, loss.__module__))
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The loss function %s is not a torch.nn.Module instance" %
                    key)
        self.losses = losses
Ejemplo n.º 5
0
    def load_metrics(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the metrics described in ``self.config.metrics``.

        For each entry, the metric is resolved with ``get_module`` (searching
        DEEP_MODULE_METRICS). If the result is a class, it is instantiated
        with the entry's checked keyword arguments. Otherwise it is used
        directly as the metric method. Each metric is wrapped in a ``Metric``
        object.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None. The metrics are stored in ``self.metrics`` as a dict
                 mapping each metric name (as a string) to its ``Metric`` object.
        """
        metrics = {}
        for key, value in self.config.metrics.get().items():
            name = str(key)
            metric = get_module(config=value, modules=DEEP_MODULE_METRICS)
            kwargs = check_kwargs(get_kwargs(self.config.metrics.get(key)))
            # Classes are instantiated with the configured kwargs; plain callables
            # are used as-is, so any configured kwargs are ignored for them.
            method = metric(**kwargs) if inspect.isclass(metric) else metric

            metrics[name] = Metric(name=name, method=method)
            Notification(
                DEEP_NOTIF_SUCCESS,
                DEEP_MSG_METRIC_LOADED % (key, value.name, metric.__module__))

        self.metrics = metrics