Example 1
 def load_transform_functions(self):
     """
     Calls the get_module function for each transform in each transformer
     Loads the returned method and edits the module entry to reflect the origin
     :return: None
     """
     # For each transform sequence
     for i, sequence in enumerate(self.output_transformer):
         # Check that the transforms entry exists
         self.__check_transforms_exists(sequence, i)
         if sequence.transforms is not None:
             for transform_name, transform_info in sequence.transforms.get(
             ).items():
                 self.__transform_loading(transform_name, transform_info)
                 method, module_path = get_module(
                     **transform_info.get(ignore="kwargs"),
                     browse=DEEP_MODULE_TRANSFORMS)
                 if method is None:
                     self.__method_not_found(transform_info)
                 if isinstance(method, types.FunctionType):
                     transform_info.add({"method": method})
                 else:
                     try:
                         transform_info.add({
                             "method":
                             method(**transform_info.kwargs.get())
                         })
                     except TypeError as e:
                         Notification(DEEP_NOTIF_FATAL, str(e))
                 transform_info.module = module_path
                 Notification(
                     DEEP_NOTIF_SUCCESS,
                     "Loaded transform : %s : %s from %s" %
                     (transform_name, transform_info.name,
                      transform_info.module))
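Most examples on this page hinge on get_module resolving a name into a callable together with the module path it was found in. As a rough, stdlib-only sketch of that behaviour (the real deeplodocus signature and the layout of the DEEP_MODULE_* browse dictionaries are assumptions inferred from the calls above, and get_module_sketch is an invented name):

import importlib

def get_module_sketch(name, module=None, browse=None):
    # Returns (attribute, module_path), or (None, None) when nothing matches.
    # 'browse' is assumed to map labels to entries that hold an import path.
    paths = [module] if module else [entry.get("path") for entry in (browse or {}).values()]
    for path in paths:
        if path is None:
            continue
        try:
            imported = importlib.import_module(path)
        except ImportError:
            continue
        if hasattr(imported, name):
            return getattr(imported, name), path
    return None, None

# Stdlib-only usage: resolve OrderedDict from the collections module.
method, module_path = get_module_sketch("OrderedDict", browse={"std": {"path": "collections"}})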
Example 2
    def load(config: Namespace):
        """
        AUTHORS:
        --------

        :author: Samuel Westlake
        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the model

        PARAMETERS:
        -----------

        :param config(Namespace): The parameters from the model config file

        RETURN:
        -------

        :return: The instantiated model
        """
        model = get_module(config=config, modules=DEEP_MODULE_MODELS)
        kwargs = check_kwargs(config.kwargs)
        return model(**kwargs)
Example 3
    def __init__(self,
                 name: str,
                 module_path: Union[str, None],
                 weight: float,
                 kwargs: dict = None):
        # Get the metric object
        method, module_path = get_module(name=name,
                                         module=module_path,
                                         browse={**DEEP_MODULE_LOSSES})

        # If metric is not found by get_module, raise DEEP_FATAL
        if method is None:
            raise DeepError(DEEP_MSG_METRIC_NOT_FOUND % name)

        # If method is a class, initialise it with metric.kwargs
        kwargs = {} if kwargs is None else kwargs
        if inspect.isclass(method):
            method = method(**kwargs)

        self.name = name
        self.module_path = module_path
        self.method = method
        self.weight = weight
        self.is_custom = self.check_custom()
        self.args = self.__check_args()
        self.kwargs = kwargs
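The check above, instantiate when get_module hands back a class and keep a plain function untouched, recurs throughout these examples (see also Examples 8, 10 and 16). A minimal stand-alone sketch of that dispatch, with an invented helper name:

import inspect

def resolve_callable(obj, **kwargs):
    # Classes are instantiated with the configured kwargs; functions pass through as-is.
    return obj(**kwargs) if inspect.isclass(obj) else obj

# resolve_callable(dict, a=1) -> {'a': 1}; resolve_callable(len) -> the len builtin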
Example 4
    def __fill_transform_list(self, config_transforms: Namespace):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Fill the list of transforms with the corresponding methods and arguments

        PARAMETERS:
        -----------

        :param config_transforms (Namespace): The transforms configuration

        RETURN:
        -------

        :return list_transforms (list): The list of loaded transforms
        """
        list_transforms = []

        for key, value in config_transforms.get().items():
            list_transforms.append([
                key,
                get_module(value, modules=DEEP_MODULE_TRANSFORMS),
                check_kwargs(get_kwargs(value.get()))
            ])
        return list_transforms
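Example 5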
    def generate_sources(self, sources: List[dict]) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Generate the sources
        Does not generate the SourcePointer instances

        PARAMETERS:
        -----------

        :param sources(List[dict]): The configuration of the Source instances to generate

        RETURN:
        -------

        :return: None
        """
        list_sources = []
        # Create sources
        for i, source in enumerate(sources):

            # Get the Source module and add it to the list
            m = "default modules" if source["module"] is None else source["module"]
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_LOADING % ("source [%i]" % i, source["name"], m))
            method, module_path = get_module(name=source["name"], module=source["module"], browse=DEEP_MODULE_SOURCES)
            if method is None:
                Notification(DEEP_NOTIF_FATAL, DEEP_MSG_MODULE_NOT_FOUND % (source["name"], module_path))
            else:
                Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_LOADED % ("source [%i]" % i, source["name"], module_path))

            # If the source is not a real source
            if not issubclass(method, Source):
                # Remove the id from the initial kwargs
                index = source["kwargs"].pop('index', None)

                # Create a source wrapper with the new ID
                s = SourceWrapper(
                    index=index,
                    name=source["name"],
                    module=source["module"],
                    kwargs=source["kwargs"]
                )
            else:
                # If the subclass is a real Source
                s = method(**source["kwargs"])

            # Check the module inherits the generic Source class
            self.check_type_sources(s, i)

            # Add the module to the list of Source instances
            list_sources.append(s)

        # Set the list as the attribute
        self.sources = list_sources
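Example 6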
 def __import(self, *args):
     for arg in args:
         items = []
         for m in DEEP_LIST_MODULE:
             method, module_path = get_module(arg, browse=m)
             if method is not None:
                 items.append({
                     "method": method,
                     "module_path": module_path,
                     "file": inspect.getsourcefile(method),
                     "prefix": m["custom"]["prefix"],
                     "name": arg
                 })
         if not items:
             Notification(DEEP_NOTIF_ERROR, "module not found : %s" % arg)
         elif len(items) == 1:
             self.__import_module(items[0])
         else:
             Notification(DEEP_NOTIF_INFO,
                          "Multiple modules found with the same name :")
             for i, item in enumerate(items):
                 Notification(
                     DEEP_NOTIF_INFO,
                     "  %i. %s from %s" % (i, arg, item["module_path"]))
             while True:
                 i = Notification(DEEP_NOTIF_INPUT,
                                  "Which would you like?").get()
                 try:
                     i = int(i)
                     break
                 except ValueError:
                     pass
             self.__import_module(items[i])
Example 7
    def load(config, model_parameters):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Load the optimizer in memory

        PARAMETERS:
        -----------

        :param config: The optimizer configuration
        :param model_parameters: The parameters of the model to optimize

        RETURN:
        -------

        :return: The instantiated optimizer
        """
        optimizer = get_module(config=config, modules=DEEP_MODULE_OPTIMIZERS)
        kwargs = check_kwargs(config.kwargs)
        return optimizer(model_parameters, **kwargs)
Example 8
 def __init__(self,
              name: str,
              module_path: Union[str, None],
              reduce: str = "mean",
              ignore_value: float = None,
              kwargs: dict = None):
     # Get the metric object
     method, module_path = get_module(name=name,
                                      module=module_path,
                                      browse={
                                          **DEEP_MODULE_METRICS,
                                          **DEEP_MODULE_LOSSES
                                      })
     # If metric is not found by get_module, raise DEEP_FATAL
     if method is None:
         raise DeepError(DEEP_MSG_METRIC_NOT_FOUND % name)
     # If method is a class, initialise it with metric.kwargs
     kwargs = {} if kwargs is None else kwargs
     if inspect.isclass(method):
         method = method(**kwargs)
     self.name = name
     self.module_path = module_path
     self.method = method
     self.args = self.__check_args()
     self.kwargs = kwargs
     self.reduce_method = None
     self.ignore_value = ignore_value
     self.reduce = reduce
Example 9
def load_optimizer(name, module, model_parameters, kwargs):
    optimizer, module = get_module(name=name,
                                   module=module,
                                   browse=DEEP_MODULE_OPTIMIZERS)

    class Optimizer(optimizer):
        def __init__(self, name, module, model_parameters, kwargs_dict,
                     **kwargs):
            super(Optimizer, self).__init__(model_parameters, **kwargs)
            self.name = name
            self.module = module
            self.kwargs = kwargs_dict

        def summary(self):
            Notification(
                DEEP_NOTIF_INFO,
                '================================================================'
            )
            Notification(DEEP_NOTIF_INFO,
                         "OPTIMIZER : %s from %s" % (self.name, self.module))
            Notification(
                DEEP_NOTIF_INFO,
                '================================================================'
            )
            for key, value in self.kwargs.items():
                if key != "name":
                    Notification(DEEP_NOTIF_INFO, "%s : %s" % (key, value))
            Notification(DEEP_NOTIF_INFO, "")

    return Optimizer(name, module, model_parameters, kwargs.get(),
                     **kwargs.get())
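This example subclasses the resolved optimizer class at runtime so the instance can carry extra metadata (name, module, original kwargs). A minimal sketch of the same wrapping pattern against the stock torch API, with illustrative names and values:

import torch

def wrap_optimizer(optimizer_cls, name, parameters, **kwargs):
    class Wrapped(optimizer_cls):
        def __init__(self, parameters, **kw):
            super().__init__(parameters, **kw)
            self.name = name  # metadata attached to the optimizer instance
    return Wrapped(parameters, **kwargs)

model = torch.nn.Linear(4, 2)
optimizer = wrap_optimizer(torch.optim.SGD, "sgd", model.parameters(), lr=0.1)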
Example 10
    def load_metrics(self):
        """
        AUTHORS:
        --------
        :author: Alix Leroy

        DESCRIPTION:
        ------------
        Load the metrics

        PARAMETERS:
        -----------
        None

        RETURN:
        -------
        :return: None
        """
        metrics = {}
        if self.config.metrics.get():
            for key, config in self.config.metrics.get().items():

                # Get the expected metric path (for notification purposes)
                if config.module is None:
                    metric_path = "%s : %s from default modules" % (key, config.name)
                else:
                    metric_path = "%s : %s from %s" % (key, config.name, config.module)

                # Notify the user which metric is being collected and from where
                Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_LOADING % metric_path)

                # Get the metric object
                metric, module = get_module(
                    name=config.name,
                    module=config.module,
                    browse={**DEEP_MODULE_METRICS, **DEEP_MODULE_LOSSES},
                    silence=True
                )

                # If metric is not found by get_module
                if metric is None:
                    Notification(DEEP_NOTIF_FATAL, DEEP_MSG_METRIC_NOT_FOUND % config.name)

                # Check if the metric is a class or a stand-alone function
                if inspect.isclass(metric):
                    method = metric(**config.kwargs.get())
                else:
                    method = metric

                # Add to the dictionary of metrics and notify of success
                metrics[str(key)] = Metric(name=str(key), method=method)
                Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_METRIC_LOADED % (key, config.name, module))
        else:
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_NONE)

        self.metrics = Metrics(metrics)
Example 11
    def load_losses(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the losses

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        losses = {}
        for key, value in self.config.losses.get().items():
            loss = get_module(config=value, modules=DEEP_MODULE_LOSSES)
            kwargs = check_kwargs(get_kwargs(self.config.losses.get(key)))
            method = loss(**kwargs)
            # Check the weight
            if self.config.losses.check("weight", key):
                if get_int_or_float(value.weight) not in (DEEP_TYPE_INTEGER,
                                                          DEEP_TYPE_FLOAT):
                    Notification(
                        DEEP_NOTIF_FATAL,
                        "The loss function %s doesn't have a correct weight argument"
                        % key)
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The loss function %s doesn't have any weight argument" %
                    key)

            # Create the loss
            if isinstance(method, torch.nn.Module):
                losses[str(key)] = Loss(name=str(key),
                                        weight=float(value.weight),
                                        loss=method)
                Notification(
                    DEEP_NOTIF_SUCCESS,
                    DEEP_MSG_LOSS_LOADED % (key, value.name, loss.__module__))
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The loss function %s is not a torch.nn.Module instance" %
                    key)
        self.losses = losses
Example 12
    def generate_sources(self, sources: List[dict]) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Generate the sources
        Does not generate the SourcePointer instances

        PARAMETERS:
        -----------

        :param sources(List[dict]): The configuration of the Source instances to generate

        RETURN:
        -------

        :return: None
        """
        list_sources = []

        # Create sources
        for i, source in enumerate(sources):

            # Get the Source module and add it to the list
            module, origin = get_module(name=source["name"],
                                        module=source["module"],
                                        browse=DEEP_MODULE_SOURCES)

            # If the source is not a real source
            if not issubclass(module, Source):
                # Remove the id from the initial kwargs
                index = source["kwargs"].pop('index', None)

                # Create a source wrapper with the new ID
                s = SourceWrapper(index=index,
                                  name=source["name"],
                                  module=source["module"],
                                  kwargs=source["kwargs"])
            else:
                # If the subclass is a real Source
                s = module(**source["kwargs"])

            # Check the module inherits the generic Source class
            self.check_type_sources(s, i)

            # Add the module to the list of Source instances
            list_sources.append(s)

        # Set the list as the attribute
        self.sources = list_sources
Example 13
    def __fill_transform_list(
            transforms: Union[Namespace, List[dict]]) -> list:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Fill the list of transforms with the corresponding methods and arguments

        PARAMETERS:
        -----------

        :param transforms (Union[Namespace, List[dict]]): A list of transforms

        RETURN:
        -------

        :return loaded_transforms (list): The list of loaded transforms
        """

        loaded_transforms = []
        if transforms is not None:
            for transform in transforms:

                # Switch from Namespace to dict
                # transform = transform.get_all()    # Only when there is no name
                transform_name = list(transform.get_all())[0]
                transform = transform.get_all()[transform_name]

                if "module" not in transform:
                    transform["module"] = None

                module, module_path = get_module(name=transform["name"],
                                                 module=transform["module"],
                                                 browse=DEEP_MODULE_TRANSFORMS)

                t = TransformData(name=transform["name"],
                                  method=module,
                                  module_path=module_path,
                                  kwargs=transform["kwargs"])
                loaded_transforms.append(t)

        return loaded_transforms
Example 14
 def load_scheduler(self):
     Notification(
         DEEP_NOTIF_INFO, "Loading learn rate scheduler : %s from %s" %
         (self.config.training.scheduler.name,
          self.config.training.scheduler.module))
     scheduler, scheduler_module = get_module(
         **self.config.training.scheduler.get(ignore=["kwargs", "enabled"]))
     if scheduler is not None:
         self.scheduler = scheduler(
             self.optimizer, **vars(self.config.training.scheduler.kwargs))
         Notification(
             DEEP_NOTIF_SUCCESS,
             "Loaded learn rate scheduler : %s from %s" %
             (self.config.training.scheduler.name, scheduler_module))
     else:
         Notification(
             DEEP_NOTIF_FATAL,
             "Unable to import learn rate scheduler : %s from %s" %
             (self.config.training.scheduler.name,
              self.config.training.scheduler.module))
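The scheduler class returned by get_module is instantiated with the already-built optimizer as its first positional argument. With a stock torch scheduler standing in for whatever get_module would resolve (an assumption made purely for illustration), the pattern reduces to:

import torch

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler_cls = torch.optim.lr_scheduler.StepLR  # stand-in for the class resolved by get_module
scheduler = scheduler_cls(optimizer, step_size=10, gamma=0.5)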
Example 15
    def __fill_transform_list(
            transforms: Union[Namespace, List[dict]]) -> list:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Fill the list of transforms with the corresponding methods and arguments

        PARAMETERS:
        -----------

        :param transforms (Union[Namespace, List[dict]]): A list of transforms

        RETURN:
        -------

        :return loaded_transforms (list): The list of loaded transforms
        """

        loaded_transforms = []
        if transforms is not None:
            for transform in transforms:
                if "module" not in transform:
                    transform["module"] = None

                module, module_path = get_module(name=transform["name"],
                                                 module=transform["module"],
                                                 browse=DEEP_MODULE_TRANSFORMS)

                loaded_transforms.append({
                    "name": transform["name"],
                    "method": module,
                    "module_path": module_path,
                    "kwargs": transform["kwargs"]
                })
        return loaded_transforms
Example 16
    def load_metrics(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the metrics

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        metrics = {}
        for key, value in self.config.metrics.get().items():
            metric = get_module(config=value, modules=DEEP_MODULE_METRICS)
            kwargs = check_kwargs(get_kwargs(self.config.metrics.get(key)))
            if inspect.isclass(metric):
                method = metric(**kwargs)
            else:
                method = metric

            metrics[str(key)] = Metric(name=str(key), method=method)
            Notification(
                DEEP_NOTIF_SUCCESS,
                DEEP_MSG_METRIC_LOADED % (key, value.name, metric.__module__))

        self.metrics = metrics
Example 17
    def __init__(self,
                 name: str,
                 module: str,
                 kwargs: dict,
                 index: int = -1,
                 is_loaded: bool = True,
                 is_transformed: bool = False,
                 num_instances: Optional[int] = None,
                 instance_id: int = 0,
                 instance_indices: Optional[List[int]] = None):

        super().__init__(index=index,
                         is_loaded=is_loaded,
                         is_transformed=is_transformed,
                         num_instances=num_instances,
                         instance_id=instance_id)

        # Module wrapped and its origin
        module, self.origin = get_module(module=module, name=name)
        # Load module
        self.module = module(**kwargs)

        # Index of the desired source
        self.instance_indices = instance_indices
Example 18
    def __fill_transform_list(
            transforms: Union[Namespace, List[dict]]) -> list:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Fill the list of transforms with the corresponding methods and arguments

        PARAMETERS:
        -----------

        :param transforms (Union[Namespace, List[dict]]): A list of transforms

        RETURN:
        -------

        :return loaded_transforms (list): The list of loaded transforms
        """

        loaded_transforms = []
        if transforms is not None:
            for transform in transforms:

                # Switch from Namespace to dict
                # transform = transform.get_all()    # Only when there is no name
                transform_name = list(transform.get_all())[0]
                transform = transform.get_all()[transform_name]

                if "module" not in transform:
                    transform["module"] = None

                m = "default modules" if transform[
                    "module"] is None else transform["module"]
                Notification(
                    DEEP_NOTIF_INFO,
                    DEEP_MSG_LOADING % ("transform", transform["name"], m))

                method, module_path = get_module(name=transform["name"],
                                                 module=transform["module"],
                                                 browse=DEEP_MODULE_TRANSFORMS)
                if method is not None:
                    Notification(
                        DEEP_NOTIF_SUCCESS, DEEP_MSG_LOADED %
                        ("transform", transform["name"], module_path))
                else:
                    Notification(
                        DEEP_NOTIF_FATAL, DEEP_MSG_MODULE_NOT_FOUND %
                        ("transform", transform["name"], m))

                t = TransformData(name=transform["name"],
                                  method=method,
                                  module_path=module_path,
                                  kwargs=transform["kwargs"])
                loaded_transforms.append(t)

        return loaded_transforms
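Examples 13, 15 and 18 all consume per-transform entries carrying at least a name, an optional module and a kwargs mapping. A plausible shape for one entry, written as a plain dict purely for illustration (the real configs are Namespace objects and the values below are invented):

transform_entry = {
    "name": "random_crop",          # illustrative transform name
    "module": None,                 # None falls back to the default transform modules
    "kwargs": {"size": [224, 224]}  # keyword arguments forwarded to the transform
}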
Example 19
def load_model(name,
               module,
               kwargs,
               device,
               model_state_dict=None,
               device_ids=None,
               input_size=None,
               batch_size=None):
    # Get the model, should be nn.Module
    module, origin = get_module(name=name,
                                module=module,
                                browse=DEEP_MODULE_MODELS)
    if module is None:
        return None

    # Define our model that inherits from the nn.Module
    class Model(module):
        def __init__(self,
                     name,
                     origin,
                     device,
                     device_ids=None,
                     input_size=None,
                     batch_size=None,
                     model_dict=None,
                     **kwargs):
            super(Model, self).__init__(**kwargs)
            self.name = name
            self.origin = origin
            self.device_ids = device_ids
            self.input_size = input_size
            self.batch_size = batch_size
            self.model_dict = {} if model_dict is None else model_dict
            self.device = device

        def summary(self):
            """
            AUTHORS:
            --------
            :author: Alix Leroy
            :author: Samuel Westlake
            DESCRIPTION:
            ------------
            Summarise the model
            PARAMETERS:
            -----------
            None
            RETURN:
            -------
            :return: None
            """

            if self.input_size is not None:
                self.__summary()
            else:
                Notification(
                    DEEP_NOTIF_ERROR,
                    "Model's input size not given, the summary cannot be displayed."
                )

        def __summary(self):
            """
            AUTHORS:
            --------
            :author:  https://github.com/sksq96/pytorch-summary
            :author: Alix Leroy
            DESCRIPTION:
            ------------
            Print a summary of the current model
            PARAMETERS:
            -----------
            None
            RETURN:
            -------
            :return: None
            """
            def register_hook(module):
                def hook(module, input, output):
                    class_name = str(
                        module.__class__).split(".")[-1].split("'")[0]
                    module_idx = len(summary)
                    m_key = "%s-%i" % (class_name, module_idx + 1)
                    summary[m_key] = OrderedDict()
                    if isinstance(input[0], list) is True:
                        summary[m_key]["input_shape"] = list(
                            input[0][0].size())
                    else:
                        summary[m_key]["input_shape"] = list(input[0].size())
                    summary[m_key]["input_shape"][0] = self.batch_size
                    if isinstance(output, (list, tuple)):
                        summary[m_key]["output_shape"] = [[-1] +
                                                          list(o.size())[1:]
                                                          for o in output]
                    else:
                        summary[m_key]["output_shape"] = list(output.size())
                        summary[m_key]["output_shape"][0] = self.batch_size
                    params = 0
                    if hasattr(module, "weight") and hasattr(
                            module.weight, "size"):
                        params += torch.prod(
                            torch.LongTensor(list(module.weight.size())))
                        summary[m_key][
                            "trainable"] = module.weight.requires_grad
                    if hasattr(module, "bias") and hasattr(
                            module.bias, "size"):
                        params += torch.prod(
                            torch.LongTensor(list(module.bias.size())))
                    summary[m_key]["nb_params"] = params

                if (not isinstance(module, nn.Sequential)
                        and not isinstance(module, nn.ModuleList)
                        and not (module == model)):
                    hooks.append(module.register_forward_hook(hook))

            # Batch_size of 2 for batchnorm
            x = [torch.rand(2, *in_size) for in_size in input_size]

            # Move the batch to the same device as the model
            x = [i.to(self.device) for i in x]

            # Create properties
            summary = OrderedDict()
            hooks = []

            # Register hook
            self.apply(register_hook)

            # Make a forward pass
            try:
                self.forward(*x)
            except RuntimeError as e:
                if "channels" in str(e):
                    Notification(DEEP_NOTIF_FATAL,
                                 str(e),
                                 solutions=[DEEP_MSG_MODEL_CHECK_CHANNELS])
                else:
                    Notification(DEEP_NOTIF_FATAL, str(e))

            # Remove these hooks
            for h in hooks:
                h.remove()

            Notification(
                DEEP_NOTIF_INFO,
                '================================================================'
            )
            Notification(DEEP_NOTIF_INFO,
                         "MODEL : %s from %s" % (self.name, self.origin))
            Notification(
                DEEP_NOTIF_INFO,
                '================================================================'
            )
            line_new = '{:>20}  {:>25} {:>15}'.format('Layer (type)',
                                                      'Output Shape',
                                                      'Param #')
            Notification(DEEP_NOTIF_INFO, line_new)
            Notification(
                DEEP_NOTIF_INFO,
                '----------------------------------------------------------------'
            )
            total_params = 0
            total_output = 0
            trainable_params = 0
            for layer in summary:
                # Input_shape, output_shape, trainable, nb_params
                line_new = '{:>20}  {:>25} {:>15}'.format(
                    layer, str(summary[layer]['output_shape']),
                    '{0:,}'.format(summary[layer]['nb_params']))
                total_params += summary[layer]['nb_params']
                total_output += np.prod(summary[layer]["output_shape"])
                if 'trainable' in summary[layer]:
                    if summary[layer]['trainable'] == True:
                        trainable_params += summary[layer]['nb_params']
                Notification(DEEP_NOTIF_INFO, line_new)

            # Assume 4 bytes/number (float on cuda).
            total_input_size = abs(
                np.prod(input_size) * self.batch_size * 4. / (1024**2.))
            total_output_size = abs(2. * total_output * 4. /
                                    (1024**2.))  # x2 for gradients
            total_params_size = abs(total_params.numpy() * 4. / (1024**2.))
            total_size = total_params_size + total_output_size + total_input_size
            Notification(
                DEEP_NOTIF_INFO,
                '----------------------------------------------------------------'
            )
            Notification(DEEP_NOTIF_INFO, "Input size : %s" % self.input_size)
            Notification(DEEP_NOTIF_INFO, "Batch size : %s" % self.batch_size)
            for key, item in self.model_dict.items():
                Notification(DEEP_NOTIF_INFO, "%s : %s" % (key, item))
            Notification(
                DEEP_NOTIF_INFO,
                '----------------------------------------------------------------'
            )
            Notification(DEEP_NOTIF_INFO,
                         'Total params: {0:,}'.format(total_params))
            Notification(DEEP_NOTIF_INFO,
                         'Trainable params: {0:,}'.format(trainable_params))
            Notification(
                DEEP_NOTIF_INFO,
                'Non-trainable params: {0:,}'.format(total_params -
                                                     trainable_params))
            Notification(
                DEEP_NOTIF_INFO,
                '----------------------------------------------------------------'
            )
            Notification(DEEP_NOTIF_INFO,
                         "Input size (MB): %0.2f" % total_input_size)
            Notification(
                DEEP_NOTIF_INFO,
                "Forward/backward pass size (MB): %0.2f" % total_output_size)
            Notification(DEEP_NOTIF_INFO,
                         "Params size (MB): %0.2f" % total_params_size)
            Notification(DEEP_NOTIF_INFO,
                         "Estimated Total Size (MB): %0.2f" % total_size)
            Notification(
                DEEP_NOTIF_INFO,
                '----------------------------------------------------------------'
            )

    # Initialise the model
    model = Model(name=name,
                  origin=origin,
                  device=device,
                  device_ids=device_ids,
                  input_size=input_size,
                  batch_size=batch_size,
                  model_dict=kwargs,
                  **kwargs)

    if model_state_dict is not None:
        model_state_dict = dict(
            (k[7:], v) if k.startswith("module.") else (k, v)
            for k, v in model_state_dict.items())
        model.load_state_dict(model_state_dict)

    n_devices = torch.cuda.device_count() if device_ids is None else len(
        device_ids)
    if n_devices > 1:
        model = DataParallelModel(module=model)

    # Send to the appropriate device
    model.to(device)

    return model
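The state_dict remapping near the end of load_model strips the "module." prefix that nn.DataParallel adds to every parameter key. A self-contained illustration of that key rewrite, simulating a prefixed checkpoint instead of using real GPUs:

import torch

net = torch.nn.Linear(3, 3)
# A checkpoint saved from an nn.DataParallel wrapper prefixes every key with "module."
state = {"module." + k: v for k, v in net.state_dict().items()}
clean = dict((k[7:], v) if k.startswith("module.") else (k, v) for k, v in state.items())
net.load_state_dict(clean)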
Example 20
def load_optimizer(name,
                   module,
                   model,
                   kwargs,
                   param_groups=None,
                   verbose=True):

    pgs = param_groups
    if param_groups is not None:
        param_groups = make_param_groups(param_groups, model, verbose=verbose)

    optimizer, module = get_module(name=name,
                                   module=module,
                                   browse=DEEP_MODULE_OPTIMIZERS)

    class Optimizer(optimizer):
        def __init__(self, name, module, parameters, **kwargs):
            super(Optimizer, self).__init__(parameters, **kwargs)
            self.name = name
            self.module = module

        def summary(self):
            Notification(
                DEEP_NOTIF_INFO,
                '================================================================'
            )
            Notification(DEEP_NOTIF_INFO,
                         "OPTIMIZER : %s from %s" % (self.name, self.module))
            Notification(
                DEEP_NOTIF_INFO,
                '================================================================'
            )
            Notification(DEEP_NOTIF_INFO, "SORRY, NOT IMPLEMENTED YET")

    if param_groups is None:
        return Optimizer(name, module, model.parameters(), **kwargs.get())
    else:
        optimizer = Optimizer(name=name,
                              module=module,
                              parameters=param_groups[0]["params"],
                              **{
                                  **kwargs.get(),
                                  **param_groups[0]["kwargs"]
                              })
        for group in param_groups[1:]:
            optimizer.add_param_group({
                "params": group["params"],
                **{
                    **kwargs.get(),
                    **group["kwargs"]
                }
            })

        # Print summary of the parameter groups
        if verbose:
            kwargs = [{
                k: v
                for k, v in sorted(group.items()) if k != "params"
            } for group in optimizer.param_groups]
            for i, (param_group, pg) in enumerate(zip(param_groups, pgs)):
                Notification(
                    DEEP_NOTIF_INFO,
                    "Parameter Group %s : %s : %s sets : kwargs=%s" %
                    (str(i).ljust(3), ("%s" % pg.condition).ljust(
                        max([len(str(i.condition)) for i in pgs])),
                     str(len(param_group["params"])).rjust(4), kwargs[i]))
        return optimizer
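Here the optimizer is built from the first parameter group and the remaining groups are appended with add_param_group, each merging the global kwargs with its own overrides. A compact sketch of that merge against stock torch.optim, with illustrative groups and values:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.Linear(4, 2))
base_kwargs = {"lr": 0.01, "momentum": 0.9}

# The first group initialises the optimizer; its overrides win over the base kwargs.
optimizer = torch.optim.SGD(model[0].parameters(), **{**base_kwargs, "lr": 0.1})
# Remaining groups are appended with their own merged kwargs.
optimizer.add_param_group({"params": list(model[1].parameters()), **{**base_kwargs, "lr": 0.001}})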
Example 21
    def load_losses(self):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Load the losses

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        losses = {}
        if self.config.losses.get():
            for key, config in self.config.losses.get().items():

                # Get the expected loss path (for notification purposes)
                if config.module is None:
                    loss_path = "%s : %s from default modules" % (key,
                                                                  config.name)
                else:
                    loss_path = "%s : %s from %s" % (key, config.name,
                                                     config.module)

                # Notify the user which loss is being collected and from where
                Notification(DEEP_NOTIF_INFO,
                             DEEP_MSG_LOSS_LOADING % loss_path)

                # Get the loss object
                loss, module = get_module(name=config.name,
                                          module=config.module,
                                          browse=DEEP_MODULE_LOSSES)
                method = loss(**config.kwargs.get())

                # Check the weight
                if self.config.losses.check("weight", key):
                    if get_corresponding_flag(
                            flag_list=[DEEP_DTYPE_INTEGER, DEEP_DTYPE_FLOAT],
                            info=get_int_or_float(config.weight),
                            fatal=False) is None:
                        Notification(
                            DEEP_NOTIF_FATAL,
                            "The loss function %s doesn't have a correct weight argument"
                            % key)
                else:
                    Notification(
                        DEEP_NOTIF_FATAL,
                        "The loss function %s doesn't have any weight argument"
                        % key)

                # Create the loss
                if isinstance(method, torch.nn.Module):
                    losses[str(key)] = Loss(name=str(key),
                                            weight=float(config.weight),
                                            loss=method)
                    Notification(
                        DEEP_NOTIF_SUCCESS,
                        DEEP_MSG_LOSS_LOADED % (key, config.name, module))
                else:
                    Notification(
                        DEEP_NOTIF_FATAL,
                        DEEP_MSG_LOSS_NOT_TORCH % (key, config.name, module))
            self.losses = Losses(losses)
        else:
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_LOSS_NONE)