Exemplo n.º 1
0
    def __init__(
        self,
        model=None,
        optimizer=None,
        overwatch: OverWatch = None,
        save_directory: str = "weights",
        save_signal: Flag = DEEP_EVENT_EPOCH_END,
        method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
        overwrite: bool = False,
    ):
        """
        DESCRIPTION:
        ------------

        Store the saving configuration, resolve the save signal and the save format into flags,
        select the matching file extension and register `save_model` on the save-model event.

        PARAMETERS:
        -----------

        :param model: The model, stored as given
        :param optimizer: The optimizer, stored as given
        :param overwatch (OverWatch): The over-watch object, stored as given
        :param save_directory (str): Stored as `self.directory`
        :param save_signal (Flag): Raw save signal, resolved against DEEP_LIST_SAVE_SIGNAL
        :param method (Flag): Raw save format, resolved against DEEP_LIST_SAVE_FORMATS
        :param overwrite (bool): Stored as given

        RETURN:
        -------

        :return: None
        """
        # Plain configuration, kept exactly as received
        self.model, self.optimizer = model, optimizer
        self.directory = save_directory
        self.overwatch = overwatch
        self.overwrite = overwrite

        # Raw user values resolved into flags
        self.save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL, save_signal)
        self.method = get_corresponding_flag(DEEP_LIST_SAVE_FORMATS, method)  # Can be onnx or pt

        # Set the extension (left unset when the format matches neither candidate)
        known_formats = (
            (DEEP_SAVE_FORMAT_PYTORCH, DEEP_EXT_PYTORCH),
            (DEEP_SAVE_FORMAT_ONNX, DEEP_EXT_ONNX),
        )
        for save_format, extension in known_formats:
            if save_format.corresponds(self.method):
                self.extension = extension
                break

        # Connect to signals
        Thalamus().connect(
            receiver=self.save_model,
            event=DEEP_EVENT_SAVE_MODEL,
            expected_arguments=[],
        )
Exemplo n.º 2
0
    def __init__(
            self,
            metric: str = DEEP_LOG_TOTAL_LOSS,
            condition: Union[Flag, None] = DEEP_SAVE_CONDITION_LESS,
            dataset: Union[Flag, None] = DEEP_DATASET_VAL
    ):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize the instance that over-watches a metric

        PARAMETERS:
        -----------

        :param metric (str): The name of the metric to over watch
        :param condition (Union[Flag, None]): Raw condition, resolved against DEEP_LIST_SAVE_CONDITIONS
        :param dataset (Union[Flag, None]): Raw dataset, resolved against the train / validation flags;
                                            None selects DEEP_DATASET_VAL

        RETURN:
        -------

        :return: None
        """
        self.metric = metric

        # Fall back on the validation dataset when none is given
        if dataset is None:
            self.dataset = DEEP_DATASET_VAL
        else:
            candidates = [DEEP_DATASET_TRAIN, DEEP_DATASET_VAL]
            self.dataset = get_corresponding_flag(candidates, dataset)

        self.current_best = None
        self._condition = get_corresponding_flag(DEEP_LIST_SAVE_CONDITIONS, condition)

        # Placeholder, followed by the call to set_is_better()
        self._is_better = None
        self.set_is_better()
Exemplo n.º 3
0
    def __init__(self,
                 name: str = "no_model_name",
                 save_directory: str = "weights",
                 save_signal: Flag = DEEP_EVENT_ON_EPOCH_END,
                 method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
                 overwrite: bool = False):
        """
        DESCRIPTION:
        ------------

        Store the saving configuration, reset the training-state placeholders, select the file
        extension, create the save directory when needed and connect the receivers to their events.

        PARAMETERS:
        -----------

        :param name (str): Stored as given
        :param save_directory (str): Stored as `self.directory`; created unless the path is an existing file
        :param save_signal (Flag): Raw save signal, resolved against DEEP_LIST_SAVE_SIGNAL
        :param method (Flag): Raw save format, resolved against DEEP_LIST_SAVE_FORMATS
        :param overwrite (bool): Stored as given

        RETURN:
        -------

        :return: None
        """
        # Configuration
        self.name = name
        self.directory = save_directory
        self.overwrite = overwrite
        self.save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL, save_signal)
        self.method = get_corresponding_flag(DEEP_LIST_SAVE_FORMATS, method)  # Can be onnx or pt

        # State placeholders, all starting empty
        self.model = None
        self.optimizer = None
        self.inp = None
        self.training_loss = None
        self.validation_loss = None
        self.best_overwatch_metric = None
        self.epoch_index = -1
        self.batch_index = -1

        # Set the extension (left unset when the format matches neither candidate)
        for save_format, extension in ((DEEP_SAVE_FORMAT_PYTORCH, DEEP_EXT_PYTORCH),
                                       (DEEP_SAVE_FORMAT_ONNX, DEEP_EXT_ONNX)):
            if save_format.corresponds(self.method):
                self.extension = extension
                break

        # Create the directory unless the path already points to a regular file
        path_is_file = os.path.isfile(self.directory)
        if not path_is_file:
            os.makedirs(self.directory, exist_ok=True)

        # Connect every receiver to its event, in the same order as before
        connections = (
            # Connect the save to the computation of the overwatched metric
            (self.on_overwatch_metric_computed,
             DEEP_EVENT_OVERWATCH_METRIC_COMPUTED,
             ["current_overwatch_metric"]),
            (self.on_training_end, DEEP_EVENT_ON_TRAINING_END, []),
            (self.save_model, DEEP_EVENT_SAVE_MODEL, []),
            (self.set_training_loss,
             DEEP_EVENT_SEND_TRAINING_LOSS,
             ["training_loss"]),
            (self.set_save_params,
             DEEP_EVENT_SEND_SAVE_PARAMS_FROM_TRAINER,
             ["model", "optimizer", "epoch_index", "validation_loss", "inp"]),
        )
        for receiver, event, expected_arguments in connections:
            Thalamus().connect(receiver=receiver,
                               event=event,
                               expected_arguments=expected_arguments)
Exemplo n.º 4
0
    def __check_load_as(self, load_as: Union[str, None, Flag]) -> Flag:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Translate the load_as argument into the corresponding flag of DEEP_LIST_LOAD_AS.
        A missing value (None) is passed through untouched.

        PARAMETERS:
        -----------

        :param load_as(Union[str, None, Flag]): The load_as argument given in the config file

        RETURN:
        -------

        :return (Flag): The corresponding flag, or None when load_as is None
        """
        # Only a provided value is looked up
        if load_as is not None:
            return get_corresponding_flag(flag_list=DEEP_LIST_LOAD_AS,
                                          info=load_as)
        return None
Exemplo n.º 5
0
    def __check_data_type(self, data_type: Union[str, int, Flag, None]):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Resolve the data type of the entry.
        A value given by the user is looked up in DEEP_LIST_DTYPE.
        Without a value, the type is estimated from the item at index 0
        (errors can occur with complex types).

        PARAMETERS:
        -----------

        :param data_type (Union[str, int, Flag, None]): The data type in a raw format given by the user

        RETURN:
        -------

        :return data_type(Flag): The data type of the entry
        """
        # Explicit type given by the user: direct lookup
        if data_type is not None:
            return get_corresponding_flag(flag_list=DEEP_LIST_DTYPE,
                                          info=data_type)

        # No type given: estimate it from the first element of the item at index 0
        instance_example, _, _ = self.__getitem__(index=0)
        return self.__estimate_data_type(instance_example)
Exemplo n.º 6
0
 def __init__(self,
              dataset: Dataset,
              model,
              transform_manager,
              losses: Losses,
              metrics: Union[Metrics, None] = None,
              batch_size: int = 32,
              num_workers: int = 1,
              shuffle: Flag = DEEP_SHUFFLE_NONE,
              name: str = "Inferer"):
     """
     DESCRIPTION:
     ------------

     Store the inference configuration, resolve the shuffle flag and build the data loader.

     PARAMETERS:
     -----------

     :param dataset (Dataset): The dataset handed to the DataLoader
     :param model: The model, stored as given
     :param transform_manager: The transform manager, stored as given
     :param losses (Losses): The losses, stored as given
     :param metrics (Union[Metrics, None]): The metrics; an empty Metrics() is created when None
     :param batch_size (int): Batch size handed to the DataLoader
     :param num_workers (int): Number of workers handed to the DataLoader
     :param shuffle (Flag): Raw shuffle value, resolved non-fatally with DEEP_SHUFFLE_NONE as default
     :param name (str): Stored as given

     RETURN:
     -------

     :return: None
     """
     self.name = name
     self.model = model
     self.dataset = dataset
     self.transform_manager = transform_manager
     self.losses = losses

     # Never leave the metrics unset
     if metrics is None:
         self.metrics = Metrics()
     else:
         self.metrics = metrics

     self.batch_size = batch_size
     self.num_workers = num_workers

     self.shuffle = get_corresponding_flag(
         DEEP_LIST_SHUFFLE,
         shuffle,
         fatal=False,
         default=DEEP_SHUFFLE_NONE,
     )

     # The DataLoader itself never shuffles, whatever the resolved flag is
     loader_settings = dict(dataset=self.dataset,
                            batch_size=self.batch_size,
                            shuffle=False,
                            num_workers=self.num_workers)
     self.dataloader = DataLoader(**loader_settings)
Exemplo n.º 7
0
    def __check_convert_to(
            convert_to: Union[str, None, Flag]) -> Optional[Flag]:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Translate the convert_to argument into the corresponding flag of DEEP_LIST_DTYPE.
        A missing value (None) is passed through untouched.

        PARAMETERS:
        -----------

        :param convert_to(Union[str, None, Flag]): The convert_to argument given in the config file

        RETURN:
        -------

        :return (Optional[Flag]): The corresponding flag, or None when convert_to is None
        """
        # Only a provided value is looked up
        if convert_to is not None:
            return get_corresponding_flag(flag_list=DEEP_LIST_DTYPE,
                                          info=convert_to)
        return None
Exemplo n.º 8
0
    def __convert_pointer_type(self, pointer_type: Union[str, int,
                                                         Flag]) -> Flag:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Convert the pointer type into the corresponding flag of DEEP_LIST_POINTER_ENTRY

        PARAMETERS:
        -----------

        :param pointer_type (Union[str, int, Flag]): The pointer type the user wants

        RETURN:
        -------

        :return flag(Flag): The corresponding flag of entry type
                            (nothing is returned explicitly when the lookup yields None)
        """
        entry_flag = get_corresponding_flag(flag_list=DEEP_LIST_POINTER_ENTRY,
                                            info=pointer_type,
                                            fatal=True)

        # Successful lookup: hand the flag back
        if entry_flag is not None:
            return entry_flag

        # NOTE(review): with fatal=True the lookup presumably never returns None - confirm
        Notification(
            DEEP_NOTIF_FATAL,
            "The type of the following transformer's pointer does not exist :' %s'. "
            "Please check the documentation." % str(self.name))
Exemplo n.º 9
0
    def set_cv_library(self, cv_library: Flag) -> None:
        """
        AUTHORS:
        --------

        :author: Samuel Westlake
        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Store the flag of the selected computer vision library in self.cv_library
        and import the corresponding library

        PARAMETERS:
        -----------

        :param cv_library: (Flag): The flag of the computer vision library selected

        RETURN:
        -------

        :return: None
        """
        library_flag = get_corresponding_flag(flag_list=DEEP_LIST_CV_LIB,
                                              info=cv_library)
        self.cv_library = library_flag

        # NOTE(review): the import receives the raw argument, not the resolved flag - confirm this is intended
        self.__import_cv_library(cv_library=cv_library)
Exemplo n.º 10
0
 def __check_args(self) -> dict:
     """
     DESCRIPTION:
     ------------

     Inspect the argument names of the loss method and associate each of them with an entry flag.
     Custom losses are matched against DEEP_LIST_ENTRY, the others against fixed aliases.
     Any unexpected argument name triggers a fatal notification.

     PARAMETERS:
     -----------

     None

     RETURN:
     -------

     :return (dict): Mapping argument name -> entry flag
     """
     # Functions are inspected directly, other objects through their forward method
     if inspect.isfunction(self.method):
         arg_names = inspect.getfullargspec(self.method).args
     else:
         arg_names = inspect.getfullargspec(self.method.forward).args
         arg_names.remove("self")

     output_aliases = ("input", "x", "inputs")
     label_aliases = ("y", "y_hat", "label", "labels", "target", "targets")

     flag_to_arg = {}
     for arg in arg_names:
         if not self.is_custom:
             # Non-custom loss: only the fixed aliases are accepted
             if arg in output_aliases:
                 flag_to_arg[DEEP_ENTRY_OUTPUT] = arg
             elif arg in label_aliases:
                 flag_to_arg[DEEP_ENTRY_LABEL] = arg
             else:
                 Notification(DEEP_NOTIF_FATAL,
                              DEEP_MSG_LOSS_UNEXPECTED_ARG % arg)
             continue

         # Custom loss: the argument name must match a known entry
         entry_flag = get_corresponding_flag(DEEP_LIST_ENTRY,
                                             arg,
                                             fatal=False)
         if entry_flag is not None:
             flag_to_arg[entry_flag] = arg
         else:
             Notification(DEEP_NOTIF_FATAL,
                          DEEP_MSG_LOSS_UNEXPECTED_ARG % arg)

     self.check_essential_args(flag_to_arg)

     # Invert the mapping for the caller
     return dict((name, found_flag) for found_flag, name in flag_to_arg.items())
Exemplo n.º 11
0
 def forward(self,
             flag,
             outputs,
             labels,
             inputs=None,
             additional_data=None,
             model=None):
     """
     DESCRIPTION:
     ------------

     Compute every metric listed in self.names, record the results for the given dataset flag
     and return them.
     A metric returning a dict contributes all its items; each new key is also registered
     on the instance as an alias of the metric which produced it.

     PARAMETERS:
     -----------

     :param flag: Raw dataset flag, resolved non-fatally against DEEP_LIST_DATASET
     :param outputs: Forwarded to each metric
     :param labels: Forwarded to each metric
     :param inputs: Forwarded to each metric
     :param additional_data: Forwarded to each metric
     :param model: Forwarded to each metric

     RETURN:
     -------

     :return (dict): Metric name -> value (tensors are converted with .item())
     """
     flag = get_corresponding_flag(DEEP_LIST_DATASET, flag, fatal=False)
     attributes = vars(self)

     results = {}
     for metric_name in self.names:
         metric = attributes[metric_name]
         value = metric.forward(outputs=outputs,
                                labels=labels,
                                inputs=inputs,
                                additional_data=additional_data,
                                model=model)

         if isinstance(value, dict):
             results.update(value)
             # Register unseen keys as aliases of the current metric
             for key in value:
                 attributes.setdefault(key, metric)
         elif isinstance(value, torch.Tensor):
             results[metric_name] = value.item()
         else:
             results[metric_name] = value

     self.update_values(self.values[flag.name.lower()], results)
     return results
Exemplo n.º 12
0
    def __init__(
        self,
        name: str,
        type: Flag,
        entries: List[Namespace],
        num_instances: int,
        transform_manager: Optional[TransformManager],
        use_raw_data: bool = True,
    ):
        """
        DESCRIPTION:
        ------------

        Initialize the dataset: generate the entries and the pipeline entries,
        compute the number of raw instances and the length, then prepare the item order.

        PARAMETERS:
        -----------

        :param name (str): Name of the Dataset
        :param type (Flag): Raw dataset type, resolved against DEEP_LIST_DATASET
        :param entries (List[Namespace]): Entry configurations, converted with list_namespace2list_dict
        :param num_instances (int): Desired length handed to the length computation
        :param transform_manager (Optional[TransformManager]): Stored as given
        :param use_raw_data (bool): Whether we want to use raw data or only transformed data

        RETURN:
        -------

        :return: None
        """
        entry_configs = list_namespace2list_dict(entries)

        # Identity of the Dataset
        self.name = name
        self.type = get_corresponding_flag(DEEP_LIST_DATASET, type)

        # Entry instances
        self.entries = []
        self.__generate_entries(entries=entry_configs)

        # PipelineEntry instances
        self.pipeline_entries = []
        self.__generate_pipeline_entries(entries=entry_configs)

        # Number of raw instances, then length of the Dataset
        self.number_raw_instances = self.__calculate_number_raw_instances()
        self.length = self.__compute_length(
            desired_length=num_instances,
            num_raw_instances=self.number_raw_instances,
        )

        # Indices of the items, in increasing order
        self.item_order = np.arange(self.length)

        self.use_raw_data = use_raw_data
        self.transform_manager = transform_manager
Exemplo n.º 13
0
 def reduce(self, flag):
     """
     DESCRIPTION:
     ------------

     Average the values recorded for each loss of the given dataset flag and sum the averages.

     PARAMETERS:
     -----------

     :param flag: Raw dataset flag, resolved non-fatally against DEEP_LIST_DATASET

     RETURN:
     -------

     :return loss: The sum of the averaged losses
     :return losses (dict): Loss name -> averaged value (converted with .item())
     """
     flag = get_corresponding_flag(DEEP_LIST_DATASET, flag, fatal=False)
     recorded = self.values[flag.name.lower()]

     losses = {}
     for loss_name, values in recorded.items():
         mean_value = sum(values) / len(values)
         losses[loss_name] = mean_value.item()

     loss = sum(losses.values())
     return loss, losses
Exemplo n.º 14
0
 def __init__(self,
              model=None,
              optimizer=None,
              overwatch: OverWatch = None,
              save_signal: Flag = DEEP_SAVE_SIGNAL_AUTO,
              method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
              overwrite: bool = False,
              enable_train_batches: bool = True,
              enable_train_epochs: bool = True,
              enable_validation: bool = True,
              history_directory: str = "history",
              weights_directory: str = "weights"):
     """
     DESCRIPTION:
     ------------

     Build the History and the Saver from the given configuration and connect the receivers
     of this instance to the training events.

     PARAMETERS:
     -----------

     :param model: The model handed to the Saver
     :param optimizer: The optimizer handed to the Saver
     :param overwatch (OverWatch): The over-watch object handed to the Saver.
                                   When None, a fresh OverWatch() is created for this instance
     :param save_signal (Flag): Raw save signal, resolved against DEEP_LIST_SAVE_SIGNAL
     :param method (Flag): Save format handed to the Saver
     :param overwrite (bool): Handed to the Saver
     :param enable_train_batches (bool): Handed to the History
     :param enable_train_epochs (bool): Handed to the History
     :param enable_validation (bool): Handed to the History
     :param history_directory (str): Log directory handed to the History
     :param weights_directory (str): Save directory handed to the Saver

     RETURN:
     -------

     :return: None
     """
     # A default written in the signature is evaluated once, when the function is defined,
     # and would be shared by every instance: create it per instance instead
     if overwatch is None:
         overwatch = OverWatch()

     self.name = "Memory"
     save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL,
                                          save_signal)
     self.history = History(log_dir=history_directory,
                            save_signal=save_signal,
                            enable_train_batches=enable_train_batches,
                            enable_train_epochs=enable_train_epochs,
                            enable_validation=enable_validation)
     self.saver = Saver(model=model,
                        optimizer=optimizer,
                        overwatch=overwatch,
                        save_directory=weights_directory,
                        save_signal=save_signal,
                        method=method,
                        overwrite=overwrite)
     self._model = model
     self._optimizer = optimizer
     self._overwatch = overwatch

     # Connect to signals
     end_arguments = ["epoch_index", "loss", "losses", "metrics"]
     connections = (
         (self.on_batch_end, DEEP_EVENT_BATCH_END,
          ["batch_index", "num_batches", "epoch_index",
           "loss", "losses", "metrics"]),
         (self.on_epoch_end, DEEP_EVENT_EPOCH_END, list(end_arguments)),
         (self.on_validation_end, DEEP_EVENT_VALIDATION_END, list(end_arguments)),
         (self.on_train_start, DEEP_EVENT_TRAINING_START, []),
         (self.on_train_end, DEEP_EVENT_TRAINING_END, []),
         (self.send_training_loss, DEEP_EVENT_REQUEST_TRAINING_LOSS, []),
     )
     for receiver, event, expected_arguments in connections:
         Thalamus().connect(receiver=receiver,
                            event=event,
                            expected_arguments=expected_arguments)
Exemplo n.º 15
0
    def __init__(self,
                 index: int,
                 name: str,
                 etype: Flag,
                 dataset: weakref,
                 load_as: str,
                 enable_cache: bool = False,
                 cv_library: Union[str, None, Flag] = DEEP_LIB_OPENCV):

        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize an entry for the Dataset

        PARAMETERS:
        -----------

        :param index(int): ID of the entry
        :param name(str): Name of the entry
        :param etype(Flag): Raw entry type, resolved against DEEP_LIST_ENTRY
        :param dataset(weakref): Weak reference to the dataset
        :param load_as(str): Handed to the Loader
        :param enable_cache(bool): Enable cache memory for pointer
        :param cv_library(Union[str, None, Flag]): Handed to the Loader

        RETURN:
        -------

        :return: None
        """
        # Identity of the entry
        self.index = index
        self.name = name
        self.etype = get_corresponding_flag(DEEP_LIST_ENTRY, etype)

        # Weak reference to the dataset
        self.dataset = dataset

        # Loader, holding only a weak reference back to this entry
        self_reference = weakref.ref(self)
        self.loader = Loader(data_entry=self_reference,
                             load_as=load_as,
                             cv_library=cv_library)

        # Sources of the entry
        self.sources = []

        # Cache memory for pointers: a list only when enable_cache is exactly True
        self.enable_cache = enable_cache
        self.cache_memory = [] if self.enable_cache is True else None

        self.num_instances = None
Exemplo n.º 16
0
 def reset(self, flag=None):
     """
     DESCRIPTION:
     ------------

     Clear the recorded values, for one dataset flag or for all of them,
     then call the reset method of each loss / metric which exposes one.

     PARAMETERS:
     -----------

     :param flag: Raw dataset flag, resolved non-fatally against DEEP_LIST_DATASET.
                  When the lookup yields None, the values of every dataset are cleared

     RETURN:
     -------

     :return: None
     """
     flag = get_corresponding_flag(DEEP_LIST_DATASET, flag, fatal=False)

     if flag is not None:
         self.values[flag.name.lower()] = {}
     else:
         self.values = {
             dataset_flag.name.lower(): {}
             for dataset_flag in DEEP_LIST_DATASET
         }

     # Call reset method for each loss and metric; an AttributeError raised in the
     # block (e.g. no reset method) skips both the reset and its notification
     for item_name in self.names:
         with contextlib.suppress(AttributeError):
             self.__dict__[item_name].method.reset()
             kind = "Loss" if self.__class__ is Losses else "Metric"
             Notification(DEEP_NOTIF_INFO, "Reset %s : %s" % (kind, item_name))
Exemplo n.º 17
0
 def reduce(self, flag):
     """
     DESCRIPTION:
     ------------

     Reduce the values recorded for each metric of the given dataset flag with the
     reduce method of that metric. Values equal to the metric's ignore_value are dropped first.

     PARAMETERS:
     -----------

     :param flag: Raw dataset flag, resolved non-fatally against DEEP_LIST_DATASET

     RETURN:
     -------

     :return (dict): Metric name -> reduced value (float("inf") on a ZeroDivisionError)
     """
     flag = get_corresponding_flag(DEEP_LIST_DATASET, flag, fatal=False)
     recorded = self.values[flag.name.lower()]

     reduced_metrics = {}
     for metric_name, values in recorded.items():
         metric = self.__dict__[metric_name]

         # Drop the values the metric asks to ignore
         if metric.ignore_value is not None:
             values = [v for v in values if v != metric.ignore_value]

         try:
             reduced = metric.reduce_method(values)
         except ZeroDivisionError:
             reduced = float("inf")
         reduced_metrics[metric_name] = reduced

     return reduced_metrics
Exemplo n.º 18
0
    def __init__(self,
                 model: Module,
                 dataset: Dataset,
                 metrics: dict,
                 losses: dict,
                 batch_size: int = 4,
                 num_workers: int = 4,
                 verbose: Flag = DEEP_VERBOSE_BATCH):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize a GenericEvaluator instance

        PARAMETERS:
        -----------

        :param model->torch.nn.Module: The model to infer
        :param dataset->Dataset: A dataset
        :param metrics->dict: The metrics, stored as given
        :param losses->dict: The losses, stored as given
        :param batch_size->int: The number of instances per batch
        :param num_workers->int: The number of processes / threads used for data loading
        :param verbose->Flag: How verbose the class is, resolved against DEEP_LIST_VERBOSE


        RETURN:
        -------

        :return: None
        """
        super().__init__(model=model,
                         dataset=dataset,
                         batch_size=batch_size,
                         num_workers=num_workers)

        # Keep the resolved flag: it used to be overwritten right away by the raw
        # `verbose` argument, which discarded the result of the lookup
        self.verbose = get_corresponding_flag(DEEP_LIST_VERBOSE, verbose)
        self.metrics = metrics
        self.losses = losses
Exemplo n.º 19
0
 def __init__(self,
              log_dir: str = "history",
              train_batches_filename: str = "history_train_batches.csv",
              train_epochs_filename: str = "history_train_epochs.csv",
              validation_filename: str = "history_validation.csv",
              save_signal: Flag = DEEP_SAVE_SIGNAL_END_EPOCH,
              write_interval: int = 10,
              enable_train_batches: bool = True,
              enable_train_epochs: bool = True,
              enable_validation: bool = True,
              overwrite: bool = None):
     """
     DESCRIPTION:
     ------------

     Store the history configuration, build the per-log file paths, enabled switches and
     headers (each keyed by the var_name of the log flag), then call init_files().

     PARAMETERS:
     -----------

     :param log_dir (str): Directory joined with each file name using "/"
     :param train_batches_filename (str): File name of the training batches log
     :param train_epochs_filename (str): File name of the training epochs log
     :param validation_filename (str): File name of the validation log
     :param save_signal (Flag): Raw save signal, resolved against DEEP_LIST_SAVE_SIGNAL
     :param write_interval (int): Stored as given
     :param enable_train_batches (bool): Switch of the training batches log
     :param enable_train_epochs (bool): Switch of the training epochs log
     :param enable_validation (bool): Switch of the validation log
     :param overwrite (bool): Stored as given

     RETURN:
     -------

     :return: None
     """
     self.log_dir = log_dir
     self.save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL, save_signal)
     self.write_interval = write_interval
     self.overwrite = overwrite

     # The log flags are paired by position with the file names and the switches
     file_names = (train_batches_filename,
                   train_epochs_filename,
                   validation_filename)
     switches = (enable_train_batches,
                 enable_train_epochs,
                 enable_validation)

     self.file_paths = {}
     for log_flag, file_name in zip(DEEP_LIST_LOG_HISTORY, file_names):
         self.file_paths[log_flag.var_name] = "/".join((log_dir, file_name))

     self.enabled = dict(
         (log_flag.var_name, switch)
         for log_flag, switch in zip(DEEP_LIST_LOG_HISTORY, switches))

     # Batch log: every header; epoch and validation logs: every header but the batch one
     full_header = [header.name for header in DEEP_LIST_HISTORY_HEADER]
     epoch_header = [
         header.name for header in DEEP_LIST_HISTORY_HEADER
         if not header.corresponds(DEEP_LOG_BATCH)
     ]
     self.headers = {
         DEEP_LOG_TRAIN_BATCHES.var_name: full_header,
         DEEP_LOG_TRAIN_EPOCHS.var_name: epoch_header,
         DEEP_LOG_VALIDATION.var_name: list(epoch_header),  # separate list object
     }

     self._training_start = None
     self._batch_data = {}
     self._loss_data = dict.fromkeys((TRAINING, VALIDATION))
     self.init_files()
Exemplo n.º 20
0
    def __init__(
            self,
            # History
            losses: dict,
            metrics: dict,
            model_name: str = None,
            verbose: Flag = DEEP_VERBOSE_BATCH,
            memorize: Flag = DEEP_MEMORIZE_BATCHES,
            history_directory: str = "history",
            overwatch_metric: OverWatchMetric = None,
            # Saver
            save_signal: Flag = DEEP_SAVE_SIGNAL_AUTO,
            method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
            overwrite: bool = False,
            save_model_directory: str = "weights"):
        """
        DESCRIPTION:
        ------------

        Resolve the save signal, then initialize the history and the saver.

        PARAMETERS:
        -----------

        :param losses (dict): The losses handed to the history
        :param metrics (dict): The metrics handed to the history
        :param model_name (str): The name handed to the history and the saver.
                                 When None, a random alphanumeric name of size 10 is generated
        :param verbose (Flag): Handed to the history
        :param memorize (Flag): Handed to the history
        :param history_directory (str): Log directory handed to the history
        :param overwatch_metric (OverWatchMetric): Handed to the history. When None, a fresh
                                 OverWatchMetric(name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS) is created
        :param save_signal (Flag): Raw save signal, resolved against DEEP_LIST_SAVE_SIGNAL
                                   with DEEP_SAVE_SIGNAL_AUTO as default
        :param method (Flag): Save format handed to the saver
        :param overwrite (bool): Handed to the saver
        :param save_model_directory (str): Save directory handed to the saver

        RETURN:
        -------

        :return: None
        """
        # Defaults written in the signature are evaluated once, when the function is defined:
        # every instance would share the same "random" name and the same OverWatchMetric object.
        # They are therefore created here, once per instance.
        if model_name is None:
            model_name = generate_random_alphanumeric(size=10)
        if overwatch_metric is None:
            overwatch_metric = OverWatchMetric(
                name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS)

        save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL,
                                             info=save_signal,
                                             default=DEEP_SAVE_SIGNAL_AUTO)
        #
        # HISTORY
        #

        self.__initialize_history(name=model_name,
                                  metrics=metrics,
                                  losses=losses,
                                  log_dir=history_directory,
                                  verbose=verbose,
                                  memorize=memorize,
                                  save_signal=save_signal,
                                  overwatch_metric=overwatch_metric)

        #
        # SAVER
        #

        self.__initialize_saver(name=model_name,
                                save_directory=save_model_directory,
                                save_signal=save_signal,
                                method=method,
                                overwrite=overwrite)
Exemplo n.º 21
0
 def __check_args(self):
     """
     DESCRIPTION:
     ------------

     Inspect the argument names of the metric method and associate each of them with the
     corresponding entry flag of DEEP_LIST_ENTRY.
     Any unexpected argument name triggers a fatal notification.

     PARAMETERS:
     -----------

     None

     RETURN:
     -------

     :return (dict): Mapping argument name -> entry flag
     """
     # Functions are inspected directly, other objects through their forward method
     if inspect.isfunction(self.method):
         arg_names = inspect.getfullargspec(self.method).args
     else:
         arg_names = inspect.getfullargspec(self.method.forward).args
         arg_names.remove("self")

     flag_to_arg = {}
     for name in arg_names:
         found_flag = get_corresponding_flag(DEEP_LIST_ENTRY,
                                             name,
                                             fatal=False)
         if found_flag is not None:
             flag_to_arg[found_flag] = name
         else:
             Notification(DEEP_NOTIF_FATAL,
                          DEEP_MSG_METRIC_UNEXPECTED_ARG % name)

     self.__check_essential_args(flag_to_arg)

     # Invert the mapping for the caller
     return dict((name, found_flag) for found_flag, name in flag_to_arg.items())
Exemplo n.º 22
0
    def __check_load_as(self, load_as: Union[str, int, Flag, None]) -> Flag:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Resolve the load_as value.
        A value given by the user is looked up in DEEP_LIST_LOAD_AS.
        Without a value, the first item of the data entry is fetched: if it is reported
        as already loaded, load_as stays None, else it is estimated from that item
        (errors can occur with complex types)

        PARAMETERS:
        -----------

        :param load_as (Union[str, int, Flag, None]): The load_as value in a raw format given by the user

        RETURN:
        -------

        :return load_as(Flag): The resolved load_as flag, or None for an already loaded item
        """
        # Explicit value given by the user: direct lookup
        if load_as is not None:
            return get_corresponding_flag(flag_list=DEEP_LIST_LOAD_AS,
                                          info=load_as)

        # Get an instance
        # NOTE(review): a double-underscore attribute written inside a class body is
        # name-mangled with the enclosing class name - confirm the data entry exposes it
        data_entry = self.data_entry()
        instance_example, is_loaded, _ = data_entry.__get_first_item()

        # Automatically check the data type, unless the item is already loaded
        if is_loaded is not True:
            load_as = self.__estimate_load_as(instance_example)
        return load_as
Exemplo n.º 23
0
 def forward(self,
             flag,
             model,
             outputs,
             labels,
             inputs=None,
             additional_data=None):
     """
     DESCRIPTION:
     ------------

     Compute every loss listed in self.names, record the results for the given dataset flag
     and return their sum together with the individual values.

     PARAMETERS:
     -----------

     :param flag: Raw dataset flag, resolved non-fatally against DEEP_LIST_DATASET
     :param model: Forwarded to each loss
     :param outputs: Forwarded to each loss
     :param labels: Forwarded to each loss
     :param inputs: Forwarded to each loss
     :param additional_data: Forwarded to each loss

     RETURN:
     -------

     :return loss: The sum of the values returned by the losses
     :return losses (dict): Loss name -> value converted with .item()
     """
     flag = get_corresponding_flag(DEEP_LIST_DATASET, flag, fatal=False)

     # Values exactly as returned by each loss
     raw_losses = {
         loss_name: self.__dict__[loss_name].forward(
             model=model,
             outputs=outputs,
             labels=labels,
             inputs=inputs,
             additional_data=additional_data)
         for loss_name in self.names
     }
     self.update_values(self.values[flag.name.lower()], raw_losses)

     # The total is computed on the raw values, before the .item() conversion
     loss = sum(raw_losses.values())
     losses = {}
     for loss_name, value in raw_losses.items():
         losses[loss_name] = value.item()
     return loss, losses
Exemplo n.º 24
0
    def __check_entry_type(entry_type: Union[str, int, Flag]) -> Flag:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Translate the raw entry type into the corresponding flag of DEEP_LIST_ENTRY

        PARAMETERS:
        -----------

        :param entry_type (Union[str, int, Flag]): The raw entry type

        RETURN:
        -------

        :return entry_type(Flag): The entry type
        """
        entry_flag = get_corresponding_flag(
            flag_list=DEEP_LIST_ENTRY,
            info=entry_type,
        )
        return entry_flag
Exemplo n.º 25
0
    def __init__(self, name: str = TOTAL_LOSS, condition: Union[Flag, int, str, None] = DEEP_SAVE_CONDITION_LESS):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize the OverWatchMetric instance

        PARAMETERS:
        -----------

        :param name (str): The name of the metric to over watch
        :param condition (Union[Flag, int, str, None]): Raw condition, resolved non-fatally against
                                                        DEEP_LIST_SAVE_CONDITIONS with
                                                        DEEP_SAVE_CONDITION_LESS as default

        RETURN:
        -------

        :return: None
        """
        self.name = name
        self.value = 0.0

        # Non-fatal lookup with DEEP_SAVE_CONDITION_LESS passed as the default
        lookup_options = dict(flag_list=DEEP_LIST_SAVE_CONDITIONS,
                              info=condition,
                              fatal=False,
                              default=DEEP_SAVE_CONDITION_LESS)
        self.condition = get_corresponding_flag(**lookup_options)
Exemplo n.º 26
0
    def __init__(self,
                 metrics: dict,
                 losses: dict,
                 log_dir: str = "history",
                 train_batches_filename: str = "history_batches_training.csv",
                 train_epochs_filename: str = "history_epochs_training.csv",
                 validation_filename: str = "history_validation.csv",
                 verbose: Flag = DEEP_VERBOSE_BATCH,
                 memorize: Flag = DEEP_MEMORIZE_BATCHES,
                 save_signal: Flag = DEEP_SAVE_SIGNAL_END_EPOCH,
                 overwatch_metric: OverWatchMetric = None):
        """
        DESCRIPTION:
        ------------

        Store the history configuration, create the history queues and the history files,
        load the histories and connect the receivers of this instance to the training events.

        PARAMETERS:
        -----------

        :param metrics: The metrics; vars() is applied to it, so it must expose a __dict__
        :param losses: The losses; vars() is applied to it, so it must expose a __dict__
        :param log_dir (str): The log directory
        :param train_batches_filename (str): Stored as given
        :param train_epochs_filename (str): Stored as given
        :param validation_filename (str): Stored as given
        :param verbose (Flag): Stored as given
        :param memorize (Flag): Raw value, resolved against the batches / epochs memorize flags
        :param save_signal (Flag): Stored as given
        :param overwatch_metric (OverWatchMetric): The over-watched metric. When None, a fresh
                     OverWatchMetric(name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS) is created

        RETURN:
        -------

        :return: None
        """
        # A default written in the signature is evaluated once, when the function is defined,
        # and would be shared by every instance: create it per instance instead
        if overwatch_metric is None:
            overwatch_metric = OverWatchMetric(
                name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS)

        self.log_dir = log_dir
        self.verbose = verbose
        self.metrics = metrics
        self.losses = losses
        self.memorize = get_corresponding_flag(
            [DEEP_MEMORIZE_BATCHES, DEEP_MEMORIZE_EPOCHS], info=memorize)
        self.save_signal = save_signal
        self.overwatch_metric = overwatch_metric

        # Running metrics
        self.running_total_loss = 0
        self.running_losses = {}
        self.running_metrics = {}

        self.train_batches_history = multiprocessing.Manager().Queue()
        self.train_epochs_history = multiprocessing.Manager().Queue()
        self.validation_history = multiprocessing.Manager().Queue()

        # Add headers to history files: the loss and metric names close every header
        names = list(vars(losses).keys()) + list(vars(metrics).keys())
        epoch_columns = [WALL_TIME, RELATIVE_TIME, EPOCH, TOTAL_LOSS]
        train_batches_headers = ",".join(
            [WALL_TIME, RELATIVE_TIME, EPOCH, BATCH, TOTAL_LOSS] + names)
        train_epochs_headers = ",".join(epoch_columns + names)
        validation_headers = ",".join(epoch_columns + names)

        # Create the history files
        # NOTE(review): the base names below are fixed and do not use the *_filename
        # arguments - confirm this is intended
        self.__add_logs("history_train_batches", log_dir, ".csv",
                        train_batches_headers)
        self.__add_logs("history_train_epochs", log_dir, ".csv",
                        train_epochs_headers)
        self.__add_logs("history_validation", log_dir, ".csv",
                        validation_headers)

        self.start_time = 0
        self.paused = False

        # Filepaths (self.log_dir is already set above)
        self.train_batches_filename = train_batches_filename
        self.train_epochs_filename = train_epochs_filename
        self.validation_filename = validation_filename

        # Load histories
        self.__load_histories()

        # Connect to signals
        connections = (
            (self.on_batch_end, DEEP_EVENT_ON_BATCH_END,
             ["minibatch_index", "num_minibatches", "epoch_index",
              "total_loss", "result_losses", "result_metrics"]),
            (self.on_epoch_end, DEEP_EVENT_ON_EPOCH_END,
             ["epoch_index", "num_epochs", "num_minibatches",
              "total_validation_loss", "result_validation_losses",
              "result_validation_metrics", "num_minibatches_validation"]),
            (self.on_train_begin, DEEP_EVENT_ON_TRAINING_START, []),
            (self.on_train_end, DEEP_EVENT_ON_TRAINING_END, []),
            (self.on_epoch_start, DEEP_EVENT_ON_EPOCH_START,
             ["epoch_index", "num_epochs"]),
            (self.send_training_loss, DEEP_EVENT_REQUEST_TRAINING_LOSS, []),
        )
        for receiver, event, expected_arguments in connections:
            Thalamus().connect(receiver=receiver,
                               event=event,
                               expected_arguments=expected_arguments)
Exemplo n.º 27
0
    def __create_transformer(self, config_entry):
        """
        CONTRIBUTORS:
        -------------

        Creator : Alix Leroy

        DESCRIPTION:
        ------------

        Build the transformer described by a configuration entry.
        A None entry gives a NoTransformer, an entry recognised as a pointer
        gives a Pointer, and any other entry is wrapped in a Namespace and
        dispatched on its "method" field.

        PARAMETERS:
        -----------

        :param config_entry: The transformer config entry (None, a pointer or a transformer description)

        RETURN:
        -------

        :return transformer: The created transformer (None if no branch created one)
        """
        # NONE
        if config_entry is None:
            return NoTransformer()

        # POINTER
        if self.__is_pointer(config_entry) is True:
            # Generic Transformer as a pointer
            return Pointer(config_entry)

        # TRANSFORMER
        config = Namespace(config_entry)

        # A transformer entry must name the method to use
        if config.check("method", None) is False:
            Notification(
                DEEP_NOTIF_FATAL,
                "The following transformer does not have any method specified : "
                + str(config_entry))

        # Resolve the method name to one of the known transformer flags
        flag = get_corresponding_flag(flag_list=DEEP_LIST_TRANSFORMERS,
                                      info=config.method,
                                      fatal=False)

        # Candidate flags, tested in this order, with the class each one builds
        candidates = (
            (DEEP_TRANSFORMER_SEQUENTIAL, Sequential),  # SEQUENTIAL
            (DEEP_TRANSFORMER_ONE_OF, OneOf),  # ONE OF
            (DEEP_TRANSFORMER_SOME_OF, SomeOf),  # SOME OF
        )

        transformer = None
        try:
            for candidate_flag, transformer_class in candidates:
                if candidate_flag.corresponds(flag):
                    # Every config field except "method" is a constructor argument
                    transformer = transformer_class(
                        **config.get(ignore="method"))
                    break
            else:
                # No candidate matched : the method does not exist
                Notification(
                    DEEP_NOTIF_FATAL,
                    "Unknown transformer method specified in %s : %s" %
                    (config_entry, config.method),
                    solutions=[
                        "Ensure a valid transformer method is specified in %s"
                        % config_entry
                    ])
        except TypeError as error:
            Notification(
                DEEP_NOTIF_FATAL,
                "TypeError when loading transformer : %s : %s" %
                (config_entry, error),
                solutions=["Check the syntax of %s" % config_entry])

        return transformer
Exemplo n.º 28
0
 def reduce(self, new_value):
     """
     Set the reduce flag and refresh the reduce method.

     :param new_value: Value resolved against DEEP_LIST_REDUCE
     """
     resolved_flag = get_corresponding_flag(DEEP_LIST_REDUCE, new_value)
     self._reduce = resolved_flag
     # Re-select the reduce method now that the flag has changed
     self.__set_reduce_method()
Exemplo n.º 29
0
    def __create_transformer(self, config_entry):
        """
        CONTRIBUTORS:
        -------------

        Creator : Alix Leroy

        DESCRIPTION:
        ------------

        Create the adequate transformer for a configuration entry.
        A None entry gives a NoTransformer, an entry recognised as a pointer
        gives a Pointer, and any other entry is wrapped in a Namespace and
        dispatched on its "method" field.

        PARAMETERS:
        -----------

        :param config_entry: The transformer config entry (None, a pointer or a transformer description)

        RETURN:
        -------

        :return transformer: The created transformer
        """
        transformer = None

        # NONE
        if config_entry is None:
            transformer = NoTransformer()

        # POINTER
        elif self.__is_pointer(config_entry) is True:
            transformer = Pointer(
                config_entry)  # Generic Transformer as a pointer

        # TRANSFORMER
        else:
            config = Namespace(config_entry)

            # Check if a method is given by the user
            if config.check("method", None) is False:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The following transformer does not have any method specified : "
                    + str(config_entry))

            # Keep the method name in a local : the attribute is removed from
            # the config below but is still needed for the error message
            method = config.method

            # Get the corresponding flag
            flag = get_corresponding_flag(flag_list=DEEP_LIST_TRANSFORMERS,
                                          info=method,
                                          fatal=False)

            # Remove the method from the config so that the remaining entries
            # can be forwarded as constructor arguments
            delattr(config, 'method')

            #
            # Create the corresponding Transformer
            #

            # SEQUENTIAL
            if DEEP_TRANSFORMER_SEQUENTIAL.corresponds(flag):
                transformer = Sequential(**config.get())
            # ONE OF
            elif DEEP_TRANSFORMER_ONE_OF.corresponds(flag):
                transformer = OneOf(**config.get())
            # SOME OF
            elif DEEP_TRANSFORMER_SOME_OF.corresponds(flag):
                # The created instance must be kept, otherwise None is returned
                transformer = SomeOf(**config.get())
            # If the method does not exist
            else:
                Notification(
                    DEEP_NOTIF_FATAL,
                    "The following transformation method does not exist : " +
                    str(method))

        return transformer
Exemplo n.º 30
0
    def __init__(self,
                 model: nn.Module,
                 dataset: Dataset,
                 metrics: dict,
                 losses: dict,
                 optimizer,
                 num_epochs: int,
                 initial_epoch: int = 1,
                 batch_size: int = 4,
                 shuffle_method: Flag = DEEP_SHUFFLE_NONE,
                 num_workers: int = 4,
                 verbose: Flag = DEEP_VERBOSE_BATCH,
                 tester: Tester = None) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize a Trainer instance

        PARAMETERS:
        -----------

        :param model (torch.nn.Module): The model which has to be trained
        :param dataset (Dataset): The dataset to be trained on
        :param metrics (dict): The metrics to analyze
        :param losses (dict): The losses to use for the backpropagation
        :param optimizer: The optimizer to use for the backpropagation
        :param num_epochs (int): Number of epochs for the training
        :param initial_epoch (int): The index of the initial epoch
        :param batch_size (int): Size a minibatch
        :param shuffle_method (Flag): DEEP_SHUFFLE flag, method of shuffling to use
        :param num_workers (int): Number of processes / threads to use for data loading
        :param verbose (Flag): DEEP_VERBOSE flag, How verbose the Trainer is
        :param tester (Tester): The tester to use for validation (ignored unless it is a Tester instance)

        RETURN:
        -------

        :return: None
        """
        # Initialize the GenericEvaluator part
        super().__init__(model=model,
                         dataset=dataset,
                         metrics=metrics,
                         losses=losses,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         verbose=verbose)
        self.optimizer = optimizer
        self.initial_epoch = initial_epoch
        # Not known at construction time : both start as None
        self.epoch = None
        self.validation_loss = None
        self.num_epochs = num_epochs

        # Load shuffling method
        # (non-fatal lookup, DEEP_SHUFFLE_NONE is passed as the default)
        self.shuffle_method = get_corresponding_flag(DEEP_LIST_SHUFFLE,
                                                     shuffle_method,
                                                     fatal=False,
                                                     default=DEEP_SHUFFLE_NONE)

        # Only a real Tester instance is kept ; it receives the same metrics
        # and losses as the trainer. Anything else leaves the tester unset.
        if isinstance(tester, Tester):
            self.tester = tester  # Tester for validation
            self.tester.set_metrics(metrics=metrics)
            self.tester.set_losses(losses=losses)
        else:
            self.tester = None

        # Early stopping
        # self.stopping = Stopping(stopping_parameters)

        #
        # Connect signals
        #
        # saving_required is registered with one expected argument,
        # send_save_params with none
        Thalamus().connect(receiver=self.saving_required,
                           event=DEEP_EVENT_SAVING_REQUIRED,
                           expected_arguments=["saving_required"])
        Thalamus().connect(receiver=self.send_save_params,
                           event=DEEP_EVENT_REQUEST_SAVE_PARAMS_FROM_TRAINER,
                           expected_arguments=[])