Пример #1
0
    def __init__(self,
                 name: str = "no__model_name",
                 save_directory: str = DEEP_PATH_SAVE_MODEL,
                 save_condition: int = DEEP_SAVE_CONDITION_AUTO,
                 save_method=DEEP_SAVE_NET_FORMAT_PYTORCH):
        """
        Configure the model saver and subscribe it to the Thalamus events.

        :param name: Stored as self.name
        :param save_directory: Stored as self.directory; the directory is
                               created unless the path already is a regular file
        :param save_condition: Saving condition flag, stored as-is
        :param save_method: Saving format; DEEP_SAVE_NET_FORMAT_ONNX selects the
                            ".onnx" extension, any other value ".model"
        """
        self.name = name
        self.directory = save_directory
        self.save_condition = save_condition
        self.save_method = save_method
        self.best_overwatch_metric = None

        onnx_requested = self.save_method == DEEP_SAVE_NET_FORMAT_ONNX
        self.extension = ".onnx" if onnx_requested else ".model"

        if not os.path.isfile(self.directory):
            os.makedirs(self.directory, exist_ok=True)

        # (receiver, event, expected arguments), connected in this order
        subscriptions = (
            (self.is_saving_required,
             DEEP_EVENT_OVERWATCH_METRIC_COMPUTED,
             ["current_overwatch_metric"]),
            (self.on_training_end, DEEP_EVENT_ON_TRAINING_END, ["model"]),
            (self.save_model, DEEP_EVENT_SAVE_MODEL, ["model"]),
        )
        for receiver, event, arguments in subscriptions:
            Thalamus().connect(receiver=receiver,
                               event=event,
                               expected_arguments=arguments)
Пример #2
0
    def __init__(self,
                 name: str = "no_model_name",
                 save_directory: str = "weights",
                 save_signal: Flag = DEEP_EVENT_ON_EPOCH_END,
                 method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
                 overwrite: bool = False):
        """
        Configure the model saver and subscribe it to the Thalamus events.

        :param name: Stored as self.name
        :param save_directory: Stored as self.directory; the directory is
                               created unless the path already is a regular file
        :param save_signal: Flag resolved against DEEP_LIST_SAVE_SIGNAL
        :param method: Saving format resolved against DEEP_LIST_SAVE_FORMATS
                       (pytorch or onnx)
        :param overwrite: Stored as self.overwrite
        """
        # Configuration
        self.name = name
        self.directory = save_directory
        self.overwrite = overwrite
        self.save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL,
                                                  save_signal)
        self.method = get_corresponding_flag(DEEP_LIST_SAVE_FORMATS, method)

        # State presumably refreshed later by the receivers connected below
        # (e.g. set_save_params, set_training_loss)
        self.model = None
        self.optimizer = None
        self.inp = None
        self.training_loss = None
        self.validation_loss = None
        self.best_overwatch_metric = None
        self.epoch_index = -1
        self.batch_index = -1

        # Extension of the saved file; left unset if neither format matches
        if DEEP_SAVE_FORMAT_PYTORCH.corresponds(self.method):
            self.extension = DEEP_EXT_PYTORCH
        elif DEEP_SAVE_FORMAT_ONNX.corresponds(self.method):
            self.extension = DEEP_EXT_ONNX

        if not os.path.isfile(self.directory):
            os.makedirs(self.directory, exist_ok=True)

        # (receiver, event, expected arguments), connected in this order
        subscriptions = (
            (self.on_overwatch_metric_computed,
             DEEP_EVENT_OVERWATCH_METRIC_COMPUTED,
             ["current_overwatch_metric"]),
            (self.on_training_end, DEEP_EVENT_ON_TRAINING_END, []),
            (self.save_model, DEEP_EVENT_SAVE_MODEL, []),
            (self.set_training_loss,
             DEEP_EVENT_SEND_TRAINING_LOSS,
             ["training_loss"]),
            (self.set_save_params,
             DEEP_EVENT_SEND_SAVE_PARAMS_FROM_TRAINER,
             ["model", "optimizer", "epoch_index", "validation_loss", "inp"]),
        )
        for receiver, event, arguments in subscriptions:
            Thalamus().connect(receiver=receiver,
                               event=event,
                               expected_arguments=arguments)
Пример #3
0
 def __init__(self):
     """
     Subscribe the printing callbacks to the Thalamus print events
     (training epoch end, validation epoch end, training batch end).
     """
     hooks = (
         (DEEP_EVENT_PRINT_TRAINING_EPOCH_END,
          ["losses", "total_loss", "metrics"],
          self.training_epoch_end),
         (DEEP_EVENT_PRINT_VALIDATION_EPOCH_END,
          ["losses", "total_loss", "metrics"],
          self.validation_epoch_end),
         (DEEP_EVENT_PRINT_TRAINING_BATCH_END,
          ["losses", "total_loss", "metrics",
           "minibatch_index", "num_minibatches"],
          self.training_batch_end),
     )
     for event, arguments, receiver in hooks:
         Thalamus().connect(event=event,
                            expected_arguments=arguments,
                            receiver=receiver)
Пример #4
0
 def save_model(self):
     """Emit a DEEP_EVENT_SAVE_MODEL signal (with empty args) through the Thalamus."""
     thalamus = Thalamus()
     thalamus.add_signal(signal=Signal(event=DEEP_EVENT_SAVE_MODEL, args={}))
Пример #5
0
    def saving_required(self, saving_required: bool):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Emit a DEEP_EVENT_SAVE_MODEL signal when saving has been requested
        NB : Uses the Thalamus, so it is kept as an instance method

        PARAMETERS:
        -----------

        :param saving_required: (bool): Whether saving the model is required;
                                only the literal True triggers the signal

        RETURN:
        -------

        None
        """
        if saving_required is not True:
            return
        thalamus = Thalamus()
        thalamus.add_signal(signal=Signal(event=DEEP_EVENT_SAVE_MODEL, args={}))
Пример #6
0
    def __init__(self, config_dir):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize a Deeplodocus Brain: close previous logs, show the Logo,
        initialize the FrontalLobe part (model manager), load the configuration,
        set the device and initialize the Thalamus (signal manager)

        PARAMETERS:
        -----------

        :param config_dir: The configuration directory, stored as
                           self.config_dir before load_config() is called

        RETURN:
        -------

        :return: None
        """
        self.__close_logs(force=True)
        Logo(version=__version__)
        FrontalLobe.__init__(self)  # Model Manager
        self.config_dir = config_dir
        self.visual_cortex = None
        time.sleep(0.5)  # Wait for the UI to respond
        self.config = None
        self._config = None
        # NOTE(review): presumably reads self.config_dir and fills the config
        # attributes reset just above -- TODO confirm
        self.load_config()
        self.set_device()
        Thalamus()  # Initialize the Signal Manager
Пример #7
0
    def __init__(
        self,
        model=None,
        optimizer=None,
        overwatch: OverWatch = None,
        save_directory: str = "weights",
        save_signal: Flag = DEEP_EVENT_EPOCH_END,
        method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
        overwrite: bool = False,
    ):
        """
        Store what the saver needs and subscribe save_model to DEEP_EVENT_SAVE_MODEL.

        :param model: Stored as self.model
        :param optimizer: Stored as self.optimizer
        :param overwatch: Stored as self.overwatch
        :param save_directory: Stored as self.directory
        :param save_signal: Flag resolved against DEEP_LIST_SAVE_SIGNAL
        :param method: Saving format resolved against DEEP_LIST_SAVE_FORMATS
                       (pytorch or onnx)
        :param overwrite: Stored as self.overwrite
        """
        # Objects handled by the saver
        self.model = model
        self.optimizer = optimizer
        self.overwatch = overwatch

        # Output configuration
        self.directory = save_directory
        self.overwrite = overwrite
        resolved_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL,
                                                 save_signal)
        self.save_signal = resolved_signal
        resolved_method = get_corresponding_flag(DEEP_LIST_SAVE_FORMATS, method)
        self.method = resolved_method

        # Extension of the saved file; left unset if neither format matches
        if DEEP_SAVE_FORMAT_PYTORCH.corresponds(self.method):
            self.extension = DEEP_EXT_PYTORCH
        elif DEEP_SAVE_FORMAT_ONNX.corresponds(self.method):
            self.extension = DEEP_EXT_ONNX

        Thalamus().connect(receiver=self.save_model,
                           event=DEEP_EVENT_SAVE_MODEL,
                           expected_arguments=[])
Пример #8
0
 def send_epoch_end_signal(self, **kwargs):
     """
     Emit DEEP_EVENT_EPOCH_END with the epoch's training summary.

     Extra keyword arguments are forwarded; the "epoch_index", "loss",
     "losses" and "metrics" entries always come from the current state.
     """
     payload = {
         **kwargs,
         "epoch_index": self.epoch,
         "loss": self.train_loss,
         "losses": self.train_losses,
         "metrics": self.train_metrics,
     }
     thalamus = Thalamus()
     thalamus.add_signal(signal=Signal(event=DEEP_EVENT_EPOCH_END, args=payload))
Пример #9
0
 def send_training_loss(self):
     """
     Emit DEEP_EVENT_SEND_TRAINING_LOSS with the self._loss_data entry stored
     under DEEP_LOG_VALIDATION.var_name.
     """
     # NOTE(review): the payload is keyed and filled from DEEP_LOG_VALIDATION
     # although the event carries the "training loss" -- confirm this is intended.
     thalamus = Thalamus()
     key = DEEP_LOG_VALIDATION.var_name
     thalamus.add_signal(
         Signal(event=DEEP_EVENT_SEND_TRAINING_LOSS,
                args={key: self._loss_data[key]}))
Пример #10
0
    def send_save_params(self, inp=None) -> None:
        """
        AUTHORS:
        --------

        :author: Samuel Westlake
        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Send the saving parameters (model, optimizer, current epoch, validation
        loss and model input) to the Saver with a
        DEEP_EVENT_SEND_SAVE_PARAMS_FROM_TRAINER signal

        PARAMETERS:
        -----------

        :param inp: The input size of the model (required for ONNX models)

        RETURN:
        -------

        :return: None
        """
        thalamus = Thalamus()
        save_params = dict(model=self.model,
                           optimizer=self.optimizer,
                           epoch_index=self.epoch,
                           validation_loss=self.validation_loss,
                           inp=inp)
        thalamus.add_signal(
            Signal(event=DEEP_EVENT_SEND_SAVE_PARAMS_FROM_TRAINER,
                   args=save_params))
Пример #11
0
    def saving_required(self, saving_required: bool):
        """
        Emit DEEP_EVENT_SAVE_MODEL, carrying self.model, when saving is requested.

        :param saving_required: Only the literal True triggers the signal
        :return: None
        """
        if saving_required is not True:
            return
        thalamus = Thalamus()
        thalamus.add_signal(signal=Signal(event=DEEP_EVENT_SAVE_MODEL,
                                          args={"model": self.model}))
Пример #12
0
 def __init__(self,
              model=None,
              optimizer=None,
              overwatch: OverWatch = None,
              save_signal: Flag = DEEP_SAVE_SIGNAL_AUTO,
              method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
              overwrite: bool = False,
              enable_train_batches: bool = True,
              enable_train_epochs: bool = True,
              enable_validation: bool = True,
              history_directory: str = "history",
              weights_directory: str = "weights"):
     """
     Create the Memory: a History (logs) and a Saver (weights), both given the
     same resolved save signal, then subscribe to the training lifecycle events.

     :param model: Model given to the Saver
     :param optimizer: Optimizer given to the Saver
     :param overwatch: OverWatch given to the Saver; when omitted (None) a new
                       OverWatch() is created for this Memory
     :param save_signal: Flag resolved against DEEP_LIST_SAVE_SIGNAL
     :param method: Saving format flag forwarded to the Saver
     :param overwrite: Forwarded to the Saver
     :param enable_train_batches: Forwarded to the History
     :param enable_train_epochs: Forwarded to the History
     :param enable_validation: Forwarded to the History
     :param history_directory: Log directory of the History
     :param weights_directory: Save directory of the Saver
     """
     # The default used to be ``OverWatch()`` in the signature: evaluated once
     # at definition time, so every Memory shared the same OverWatch object.
     if overwatch is None:
         overwatch = OverWatch()

     self.name = "Memory"
     save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL,
                                          save_signal)
     self.history = History(log_dir=history_directory,
                            save_signal=save_signal,
                            enable_train_batches=enable_train_batches,
                            enable_train_epochs=enable_train_epochs,
                            enable_validation=enable_validation)
     self.saver = Saver(model=model,
                        optimizer=optimizer,
                        overwatch=overwatch,
                        save_directory=weights_directory,
                        save_signal=save_signal,
                        method=method,
                        overwrite=overwrite)
     self._model = model
     self._optimizer = optimizer
     self._overwatch = overwatch

     # Connect to signals
     Thalamus().connect(receiver=self.on_batch_end,
                        event=DEEP_EVENT_BATCH_END,
                        expected_arguments=[
                            "batch_index", "num_batches", "epoch_index",
                            "loss", "losses", "metrics"
                        ])
     Thalamus().connect(
         receiver=self.on_epoch_end,
         event=DEEP_EVENT_EPOCH_END,
         expected_arguments=["epoch_index", "loss", "losses", "metrics"])
     Thalamus().connect(
         receiver=self.on_validation_end,
         event=DEEP_EVENT_VALIDATION_END,
         expected_arguments=["epoch_index", "loss", "losses", "metrics"])
     Thalamus().connect(receiver=self.on_train_start,
                        event=DEEP_EVENT_TRAINING_START,
                        expected_arguments=[])
     Thalamus().connect(receiver=self.on_train_end,
                        event=DEEP_EVENT_TRAINING_END,
                        expected_arguments=[])
     Thalamus().connect(receiver=self.send_training_loss,
                        event=DEEP_EVENT_REQUEST_TRAINING_LOSS,
                        expected_arguments=[])
Пример #13
0
    def __compute_overwatch_metric(self, num_minibatches_training,
                                   running_total_loss, running_losses,
                                   running_metrics, total_validation_loss,
                                   result_validation_losses,
                                   result_validation_metrics) -> None:
        """
        Update self.overwatch_metric from the epoch results and broadcast a deep
        copy of it with a DEEP_EVENT_OVERWATCH_METRIC_COMPUTED signal.

        Without validation (total_validation_loss is None) the candidate values
        are the training running totals divided by num_minibatches_training;
        otherwise the validation results are used. The candidate whose name
        equals self.overwatch_metric.get_name() becomes its new value; if none
        matches, the value is left unchanged.

        :param num_minibatches_training: Number of training minibatches in the epoch
        :param running_total_loss: Accumulated training total loss
        :param running_losses: Accumulated training losses by name (values expose .item())
        :param running_metrics: Accumulated training metrics by name
        :param total_validation_loss: Validation total loss, or None without validation
        :param result_validation_losses: Validation losses by name (values expose .item())
        :param result_validation_metrics: Validation metrics by name
        :return: None
        """
        # Later entries override earlier ones with the same name
        if total_validation_loss is None:
            candidates = {TOTAL_LOSS: running_total_loss / num_minibatches_training}
            candidates.update((name, value.item() / num_minibatches_training)
                              for name, value in running_losses.items())
            candidates.update((name, value / num_minibatches_training)
                              for name, value in running_metrics.items())
        else:
            candidates = {TOTAL_LOSS: total_validation_loss}
            candidates.update((name, value.item())
                              for name, value in result_validation_losses.items())
            # NOTE(review): validation metrics are divided by the *training*
            # minibatch count while validation losses are not divided at all;
            # kept as-is, but this looks like a copy-paste slip -- confirm.
            candidates.update((name, value / num_minibatches_training)
                              for name, value in result_validation_metrics.items())

        # Direct lookup instead of scanning every entry for the matching name
        target = self.overwatch_metric.get_name()
        if target in candidates:
            self.overwatch_metric.set_value(candidates[target])

        Thalamus().add_signal(
            Signal(event=DEEP_EVENT_OVERWATCH_METRIC_COMPUTED,
                   args={
                       "current_overwatch_metric":
                       copy.deepcopy(self.overwatch_metric)
                   }))
Пример #14
0
    def on_batch_end(self, minibatch_index: int, num_minibatches: int,
                     epoch_index: int, total_loss: int, result_losses: dict,
                     result_metrics: dict):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Called at the end of every batch: accumulate the running loss and
        metrics, optionally request the batch stats to be printed, queue a
        history row and save the batch history once more than 10 rows wait

        PARAMETERS:
        -----------

        :param minibatch_index: int: Index of the current minibatch
        :param num_minibatches: int: Number of minibatches per epoch
        :param epoch_index: int: Index of the current epoch
        :param total_loss: int: The total loss (annotated int; presumably a
                           float or tensor -- TODO confirm)
        :param result_losses: dict: Resulting losses by name (values expose .item())
        :param result_metrics: dict: Resulting metrics by name

        RETURN:
        -------

        :return: None
        """
        # Accumulate the running totals of the epoch
        self.running_total_loss = self.running_total_loss + total_loss
        self.running_losses = merge_sum_dict(self.running_losses,
                                             result_losses)
        self.running_metrics = merge_sum_dict(self.running_metrics,
                                              result_metrics)

        # If the user wants to print stats for each batch
        if DEEP_VERBOSE_BATCH.corresponds(self.verbose):
            Thalamus().add_signal(
                Signal(event=DEEP_EVENT_PRINT_TRAINING_BATCH_END,
                       args={
                           "losses": result_losses,
                           "total_loss": total_loss,
                           "metrics": result_metrics,
                           "num_minibatches": num_minibatches,
                           "minibatch_index": minibatch_index
                       }))

        # Queue one history row: wall time, relative time, epoch, batch, total
        # loss, then every loss and every metric in dict order
        if DEEP_MEMORIZE_BATCHES.corresponds(self.memorize):
            row = [datetime.datetime.now().strftime(TIME_FORMAT),
                   self.__time(),
                   epoch_index,
                   minibatch_index,
                   total_loss]
            row.extend(loss.item() for loss in result_losses.values())
            row.extend(result_metrics.values())
            self.train_batches_history.put(row)

        # Save the batch history once more than 10 rows are queued
        if self.train_batches_history.qsize() > 10:
            self.save(only_batches=True)
Пример #15
0
 def send_validation_end_signal(**kwargs):
     """Emit DEEP_EVENT_VALIDATION_END, forwarding every keyword argument as the signal args."""
     thalamus = Thalamus()
     thalamus.add_signal(signal=Signal(event=DEEP_EVENT_VALIDATION_END,
                                       args=kwargs))
Пример #16
0
 def send_training_end_signal(**kwargs):
     """Emit DEEP_EVENT_TRAINING_END, forwarding every keyword argument as the signal args."""
     thalamus = Thalamus()
     thalamus.add_signal(signal=Signal(event=DEEP_EVENT_TRAINING_END,
                                       args=kwargs))
Пример #17
0
    def __init__(
        self,
        metrics: dict,
        losses: dict,
        log_dir: str = DEEP_PATH_HISTORY,
        train_batches_filename: str = "history_batches_training.csv",
        train_epochs_filename: str = "history_epochs_training.csv",
        validation_filename: str = "history_validation.csv",
        verbose: int = DEEP_VERBOSE_BATCH,
        memorize: int = DEEP_MEMORIZE_BATCHES,
        save_condition: int = DEEP_SAVE_CONDITION_END_EPOCH,
        overwatch_metric: OverWatchMetric = None,
    ):
        """
        Set up the training/validation history.

        Registers the three CSV history logs with their headers, creates the
        queues holding pending rows, loads the existing histories and connects
        the History to the Thalamus training events.

        :param metrics: Metrics by name; the names become CSV columns
        :param losses: Losses by name; the names become CSV columns
        :param log_dir: Directory of the history logs
        :param train_batches_filename: Stored as self.train_batches_filename
        :param train_epochs_filename: Stored as self.train_epochs_filename
        :param validation_filename: Stored as self.validation_filename
        :param verbose: DEEP_VERBOSE flag
        :param memorize: DEEP_MEMORIZE flag
        :param save_condition: DEEP_SAVE_CONDITION_END_TRAINING to save at the
                               end of training, DEEP_SAVE_CONDITION_END_EPOCH to
                               save at the end of the epoch
        :param overwatch_metric: Metric to overwatch; when omitted (None) a new
                                 OverWatchMetric(name=TOTAL_LOSS,
                                 condition=DEEP_COMPARE_SMALLER) is created
        """
        # The default used to be an OverWatchMetric built in the signature, i.e.
        # once at definition time and shared by every History, although the
        # metric object is mutated later (set_value).
        if overwatch_metric is None:
            overwatch_metric = OverWatchMetric(name=TOTAL_LOSS,
                                               condition=DEEP_COMPARE_SMALLER)

        self.log_dir = log_dir
        self.verbose = verbose
        self.metrics = metrics
        self.losses = losses
        self.memorize = memorize
        self.save_condition = save_condition
        self.overwatch_metric = overwatch_metric

        # Running metrics
        self.running_total_loss = 0
        self.running_losses = {}
        self.running_metrics = {}

        # One manager process serves the three queues; calling
        # multiprocessing.Manager() per queue started three server processes.
        manager = multiprocessing.Manager()
        self.train_batches_history = manager.Queue()
        self.train_epochs_history = manager.Queue()
        self.validation_history = manager.Queue()

        # Add headers to history files (epoch and validation logs share a layout)
        value_columns = list(losses.keys()) + list(metrics.keys())
        train_batches_headers = ",".join(
            [WALL_TIME, RELATIVE_TIME, EPOCH, BATCH, TOTAL_LOSS] + value_columns)
        epoch_headers = ",".join(
            [WALL_TIME, RELATIVE_TIME, EPOCH, TOTAL_LOSS] + value_columns)

        self.__add_logs("history_train_batches", log_dir, ".csv",
                        train_batches_headers)
        self.__add_logs("history_train_epochs", log_dir, ".csv",
                        epoch_headers)
        self.__add_logs("history_validation", log_dir, ".csv",
                        epoch_headers)

        self.start_time = 0
        self.paused = False

        # Filepaths
        self.log_dir = log_dir
        self.train_batches_filename = train_batches_filename
        self.train_epochs_filename = train_epochs_filename
        self.validation_filename = validation_filename

        # Load histories
        self.__load_histories()

        # Connect to signals
        Thalamus().connect(receiver=self.on_batch_end,
                           event=DEEP_EVENT_ON_BATCH_END)
        Thalamus().connect(receiver=self.on_epoch_end,
                           event=DEEP_EVENT_ON_EPOCH_END,
                           expected_arguments=[
                               "epoch_index", "num_epochs", "num_minibatches",
                               "total_validation_loss",
                               "result_validation_losses",
                               "result_validation_metrics",
                               "num_minibatches_validation"
                           ])
        Thalamus().connect(receiver=self.on_train_begin,
                           event=DEEP_EVENT_ON_TRAINING_START,
                           expected_arguments=[])
        Thalamus().connect(receiver=self.on_train_end,
                           event=DEEP_EVENT_ON_TRAINING_END,
                           expected_arguments=[])
        Thalamus().connect(receiver=self.on_epoch_start,
                           event=DEEP_EVENT_ON_EPOCH_START,
                           expected_arguments=["epoch_index", "num_epochs"])
Пример #18
0
 def send_batch_end_signal(**kwargs):
     """Emit DEEP_EVENT_BATCH_END, forwarding every keyword argument as the signal args."""
     thalamus = Thalamus()
     thalamus.add_signal(signal=Signal(event=DEEP_EVENT_BATCH_END,
                                       args=kwargs))
Пример #19
0
 def send_batch_start_signal(**kwargs):
     """Emit DEEP_EVENT_BATCH_START, forwarding every keyword argument as the signal args."""
     thalamus = Thalamus()
     thalamus.add_signal(signal=Signal(event=DEEP_EVENT_BATCH_START,
                                       args=kwargs))
Пример #20
0
    def save_model(self) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Save the model to the path given by self.__get_file_path()

        The training loss and the trainer's save parameters are requested
        through the Thalamus first. The attributes read afterwards (model,
        optimizer, epoch_index, training_loss, validation_loss) are assumed to
        have been refreshed by those requests -- TODO confirm the Thalamus
        delivers signals synchronously.

        With the PyTorch format, the model and optimizer state dicts are written
        with torch.save together with the epoch index and both losses. The ONNX
        format is not implemented and emits a DEEP_NOTIF_FATAL notification.

        PARAMETERS:
        -----------

        None

        RETURN:
        -------

        :return: None
        """
        # Signal args are a dict in the other emitters (e.g. args={} for
        # DEEP_EVENT_SAVE_MODEL); an empty list is not a valid mapping should
        # the arguments ever be unpacked as keywords.
        Thalamus().add_signal(
            Signal(event=DEEP_EVENT_REQUEST_TRAINING_LOSS, args={}))
        Thalamus().add_signal(
            Signal(event=DEEP_EVENT_REQUEST_SAVE_PARAMS_FROM_TRAINER, args={}))

        file_path = self.__get_file_path()

        # If we want to save to the pytorch format
        if DEEP_SAVE_FORMAT_PYTORCH.corresponds(self.method):
            # TODO: catch saving errors and report them with DEEP_NOTIF_ERROR
            torch.save(
                {
                    "model_state_dict": self.model.state_dict(),
                    "epoch": self.epoch_index,
                    "training_loss": self.training_loss,
                    "validation_loss": self.validation_loss,
                    "optimizer_state_dict": self.optimizer.state_dict()
                }, file_path)

        # If we want to save to the ONNX format
        elif DEEP_SAVE_FORMAT_ONNX.corresponds(self.method):
            # TODO: implement the export with torch.onnx.export (it needs an
            # example input -- presumably self.inp) and report errors with
            # DEEP_NOTIF_ERROR
            Notification(DEEP_NOTIF_FATAL,
                         "Save as onnx format not implemented yet")

        Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_MODEL_SAVED % file_path)
Пример #21
0
    def __init__(self,
                 model: nn.Module,
                 dataset: Dataset,
                 metrics: dict,
                 losses: dict,
                 optimizer,
                 num_epochs: int,
                 initial_epoch: int = 1,
                 batch_size: int = 4,
                 shuffle_method: Flag = DEEP_SHUFFLE_NONE,
                 num_workers: int = 4,
                 verbose: Flag = DEEP_VERBOSE_BATCH,
                 tester: Tester = None) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize a Trainer instance and connect it to the saving-related
        Thalamus events

        PARAMETERS:
        -----------

        :param model (torch.nn.Module): The model which has to be trained
        :param dataset (Dataset): The dataset to be trained on
        :param metrics (dict): The metrics to analyze
        :param losses (dict): The losses to use for the backpropagation
        :param optimizer: The optimizer to use for the backpropagation
        :param num_epochs (int): Number of epochs for the training
        :param initial_epoch (int): The index of the initial epoch
        :param batch_size (int): Size of a minibatch
        :param shuffle_method (Flag): DEEP_SHUFFLE flag, method of shuffling to
                                      use (resolved with fatal=False and a
                                      DEEP_SHUFFLE_NONE default)
        :param num_workers (int): Number of processes / threads to use for data loading
        :param verbose (Flag): DEEP_VERBOSE flag, how verbose the Trainer is
        :param tester (Tester): The tester to use for validation; kept only if
                                it is a Tester instance

        RETURN:
        -------

        :return: None
        """
        # Initialize the GenericEvaluator part
        super().__init__(model=model,
                         dataset=dataset,
                         metrics=metrics,
                         losses=losses,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         verbose=verbose)

        # Training schedule and state
        self.optimizer = optimizer
        self.num_epochs = num_epochs
        self.initial_epoch = initial_epoch
        self.epoch = None
        self.validation_loss = None

        self.shuffle_method = get_corresponding_flag(DEEP_LIST_SHUFFLE,
                                                     shuffle_method,
                                                     fatal=False,
                                                     default=DEEP_SHUFFLE_NONE)

        # A tester is kept only if it is a Tester; it gets the trainer's
        # metrics and losses
        self.tester = tester if isinstance(tester, Tester) else None
        if self.tester is not None:
            self.tester.set_metrics(metrics=metrics)
            self.tester.set_losses(losses=losses)

        # Early stopping (disabled): self.stopping = Stopping(stopping_parameters)

        subscriptions = (
            (self.saving_required,
             DEEP_EVENT_SAVING_REQUIRED,
             ["saving_required"]),
            (self.send_save_params,
             DEEP_EVENT_REQUEST_SAVE_PARAMS_FROM_TRAINER,
             []),
        )
        for receiver, event, arguments in subscriptions:
            Thalamus().connect(receiver=receiver,
                               event=event,
                               expected_arguments=arguments)
Пример #22
0
    def is_saving_required(self,
                           current_overwatch_metric: OverWatchMetric) -> bool:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Check if saving the model is required, i.e. whether the overwatched
        metric improved on the best one seen so far, and broadcast the answer
        with a DEEP_EVENT_SAVING_REQUIRED signal. The first metric received only
        becomes the reference and never triggers a save.

        PARAMETERS:
        -----------

        :param current_overwatch_metric: The newly computed overwatch metric;
               its condition is DEEP_COMPARE_SMALLER (lower is better) or
               DEEP_COMPARE_BIGGER (higher is better)

        RETURN:
        -------

        :return->bool: Whether the model should be saved or not
        """
        # The first metric only becomes the reference: comparing it with itself
        # below can never count as an improvement
        if self.best_overwatch_metric is None:
            self.best_overwatch_metric = current_overwatch_metric

        condition = current_overwatch_metric.get_condition()
        if condition == DEEP_COMPARE_SMALLER:
            save = bool(self.best_overwatch_metric.get_value()
                        > current_overwatch_metric.get_value())
        elif condition == DEEP_COMPARE_BIGGER:
            save = bool(self.best_overwatch_metric.get_value()
                        < current_overwatch_metric.get_value())
        else:
            save = False
            # Report the actual offending condition (was a "test" placeholder)
            Notification(
                DEEP_NOTIF_FATAL,
                "The following saving condition does not exist : " +
                str(condition))

        if save:
            self.best_overwatch_metric = current_overwatch_metric

        Thalamus().add_signal(signal=Signal(event=DEEP_EVENT_SAVING_REQUIRED,
                                            args={"saving_required": save}))
        return save
Пример #23
0
    def __init__(self,
                 model: Module,
                 dataset: Dataset,
                 metrics: dict,
                 losses: dict,
                 optimizer,
                 num_epochs: int,
                 initial_epoch: int = 1,
                 batch_size: int = 4,
                 shuffle: int = DEEP_SHUFFLE_ALL,
                 num_workers: int = 4,
                 verbose: int = DEEP_VERBOSE_BATCH,
                 tester: Tester = None):
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Initialize a Trainer instance and connect it to the
        DEEP_EVENT_SAVING_REQUIRED event

        PARAMETERS:
        -----------

        :param model->torch.nn.Module: The model which has to be trained
        :param dataset->Dataset: The dataset to be trained on
        :param metrics->dict: The metrics to analyze
        :param losses->dict: The losses to use for the backpropagation
        :param optimizer: The optimizer to use for the backpropagation
        :param num_epochs->int: Number of epochs for the training
        :param initial_epoch->int: The index of the initial epoch
        :param batch_size->int: Size of a minibatch
        :param shuffle->int: DEEP_SHUFFLE flag, method of shuffling to use
        :param num_workers->int: Number of processes / threads to use for data loading
        :param verbose->int: DEEP_VERBOSE flag, how verbose the Trainer is
        :param tester->Tester: The tester to use for validation; kept only if
                               it is a Tester instance

        RETURN:
        -------

        :return: None
        """
        # Initialize the GenericEvaluator part
        super().__init__(model=model,
                         dataset=dataset,
                         metrics=metrics,
                         losses=losses,
                         batch_size=batch_size,
                         num_workers=num_workers,
                         verbose=verbose)

        # Training configuration
        self.shuffle = shuffle
        self.optimizer = optimizer
        self.initial_epoch = initial_epoch
        self.num_epochs = num_epochs

        # A tester is kept only if it is a Tester; it gets the trainer's
        # metrics and losses
        self.tester = tester if isinstance(tester, Tester) else None
        if self.tester is not None:
            self.tester.set_metrics(metrics=metrics)
            self.tester.set_losses(losses=losses)

        # Early stopping (disabled): self.stopping = Stopping(stopping_parameters)

        Thalamus().connect(receiver=self.saving_required,
                           event=DEEP_EVENT_SAVING_REQUIRED,
                           expected_arguments=["saving_required"])
Пример #24
0
    def __compute_overwatch_metric(self, num_minibatches_training,
                                   running_total_loss, running_losses,
                                   running_metrics, total_validation_loss,
                                   result_validation_losses,
                                   result_validation_metrics) -> None:
        """
        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Compute the overwatch metric and send it to the saver: update
        self.overwatch_metric from the epoch results, then broadcast a deep copy
        of it with a DEEP_EVENT_OVERWATCH_METRIC_COMPUTED signal.

        Without validation (total_validation_loss is None) the candidate values
        are the training running totals divided by num_minibatches_training;
        otherwise the validation results are used. The candidate whose name
        equals self.overwatch_metric.get_name() becomes its new value; if none
        matches, the value is left unchanged.

        PARAMETERS:
        -----------

        :param num_minibatches_training: Number of training minibatches in the epoch
        :param running_total_loss: Accumulated training total loss
        :param running_losses: Accumulated training losses by name (values expose .item())
        :param running_metrics: Accumulated training metrics by name
        :param total_validation_loss: Validation total loss, or None without validation
        :param result_validation_losses: Validation losses by name (values expose .item())
        :param result_validation_metrics: Validation metrics by name

        RETURN:
        -------

        :return: None
        """
        # Later entries override earlier ones with the same name
        if total_validation_loss is None:
            candidates = {TOTAL_LOSS: running_total_loss / num_minibatches_training}
            candidates.update((name, value.item() / num_minibatches_training)
                              for name, value in running_losses.items())
            candidates.update((name, value / num_minibatches_training)
                              for name, value in running_metrics.items())
        else:
            candidates = {TOTAL_LOSS: total_validation_loss}
            candidates.update((name, value.item())
                              for name, value in result_validation_losses.items())
            # NOTE(review): validation metrics are divided by the *training*
            # minibatch count while validation losses are not divided at all;
            # kept as-is, but this looks like a copy-paste slip -- confirm.
            candidates.update((name, value / num_minibatches_training)
                              for name, value in result_validation_metrics.items())

        # Direct lookup instead of scanning every entry for the matching name
        target = self.overwatch_metric.get_name()
        if target in candidates:
            self.overwatch_metric.set_value(candidates[target])

        Thalamus().add_signal(
            Signal(event=DEEP_EVENT_OVERWATCH_METRIC_COMPUTED,
                   args={
                       "current_overwatch_metric":
                       copy.deepcopy(self.overwatch_metric)
                   }))
Пример #25
0
    def on_epoch_end(self, epoch_index: int, num_epochs: int,
                     num_minibatches: int, total_validation_loss: int,
                     result_validation_losses: dict,
                     result_validation_metrics: dict,
                     num_minibatches_validation: int):
        """
        AUTHORS:
        --------

        :author: Alix Leroy
        :author: Samuel Westlake

        DESCRIPTION:
        ------------

        Method for managing history at the end of each epoch.

        It performs the following steps:

        - Prints and/or records the epoch-averaged training values.
        - Resets the running training accumulators.
        - Prints and/or records the validation values, if any.
        - Triggers the saving logic (overwatch metric or plain save signal).
        - Notifies the end of the epoch and calls ``self.save()``.

        PARAMETERS:
        -----------

        :param epoch_index: int: current epoch index
        :param num_epochs: int: total number of epoch
        :param num_minibatches: int: number of training minibatches per epoch
        :param total_validation_loss: validation total loss, or None when no validation was run
        :param result_validation_losses: dict of validation losses (values must expose ``.item()``)
        :param result_validation_metrics: dict of validation metrics
        :param num_minibatches_validation: number of validation minibatches (not used here)

        RETURN:
        -------

        :return: None
        """
        # MANAGE TRAINING HISTORY
        if DEEP_VERBOSE_EPOCH.corresponds(
                self.verbose) or DEEP_VERBOSE_BATCH.corresponds(self.verbose):

            # Print the training loss and metrics on epoch end
            Thalamus().add_signal(
                Signal(event=DEEP_EVENT_PRINT_TRAINING_EPOCH_END,
                       args={
                           "losses": {
                               key: value / num_minibatches
                               for key, value in self.running_losses.items()
                           },
                           "total_loss":
                           self.running_total_loss / num_minibatches,
                           "metrics": {
                               key: value / num_minibatches
                               for key, value in self.running_metrics.items()
                           },
                       }))

        # If recording on batch or epoch
        if DEEP_MEMORIZE_BATCHES.corresponds(
                self.memorize) or DEEP_MEMORIZE_EPOCHS.corresponds(
                    self.memorize):
            data = [
                datetime.datetime.now().strftime(TIME_FORMAT),
                self.__time(),
                epoch_index,
                self.running_total_loss / num_minibatches
            ] \
                + [value.item() / num_minibatches for value in self.running_losses.values()] \
                + [value / num_minibatches for value in self.running_metrics.values()]
            self.train_epochs_history.put(data)

        # Keep this epoch's running sums: the accumulators are reset right
        # below, but the overwatch metric (computed at the end of this method)
        # must still see this epoch's training values. Previously it received
        # the already-reset 0 / {} accumulators.
        epoch_total_loss = self.running_total_loss
        epoch_losses = self.running_losses
        epoch_metrics = self.running_metrics

        self.running_total_loss = 0
        self.running_losses = {}
        self.running_metrics = {}

        # MANAGE VALIDATION HISTORY
        if total_validation_loss is not None:
            if DEEP_VERBOSE_EPOCH.corresponds(
                    self.verbose) or DEEP_VERBOSE_BATCH.corresponds(
                        self.verbose):

                # Print the validation loss and metrics on epoch end
                Thalamus().add_signal(
                    Signal(event=DEEP_EVENT_PRINT_VALIDATION_EPOCH_END,
                           args={
                               "losses": result_validation_losses,
                               "total_loss": total_validation_loss,
                               "metrics": result_validation_metrics,
                           }))

            if DEEP_MEMORIZE_BATCHES.corresponds(
                    self.memorize) or DEEP_MEMORIZE_EPOCHS.corresponds(
                        self.memorize):
                data = [
                    datetime.datetime.now().strftime(TIME_FORMAT),
                    self.__time(),
                    epoch_index,
                    total_validation_loss
                ] \
                    + [value.item() for value in result_validation_losses.values()] \
                    + [value for value in result_validation_metrics.values()]
                self.validation_history.put(data)

        if DEEP_SAVE_SIGNAL_AUTO.corresponds(self.save_signal):
            self.__compute_overwatch_metric(
                num_minibatches_training=num_minibatches,
                running_total_loss=epoch_total_loss,
                running_losses=epoch_losses,
                running_metrics=epoch_metrics,
                total_validation_loss=total_validation_loss,
                result_validation_losses=result_validation_losses,
                result_validation_metrics=result_validation_metrics)
        elif DEEP_SAVE_SIGNAL_END_EPOCH.corresponds(self.save_signal):
            Thalamus().add_signal(Signal(event=DEEP_EVENT_SAVE_MODEL, args={}))

        Notification(DEEP_NOTIF_SUCCESS, EPOCH_END % (epoch_index, num_epochs))
        self.save()
Пример #26
0
    def __init__(self,
                 metrics: dict,
                 losses: dict,
                 log_dir: str = "history",
                 train_batches_filename: str = "history_batches_training.csv",
                 train_epochs_filename: str = "history_epochs_training.csv",
                 validation_filename: str = "history_validation.csv",
                 verbose: Flag = DEEP_VERBOSE_BATCH,
                 memorize: Flag = DEEP_MEMORIZE_BATCHES,
                 save_signal: Flag = DEEP_SAVE_SIGNAL_END_EPOCH,
                 overwatch_metric: OverWatchMetric = None):
        """
        Set up the training/validation history.

        This constructor:

        - creates the history queues;
        - creates the CSV history files with their headers;
        - loads existing histories;
        - connects the history handlers to the training events.

        PARAMETERS:
        -----------

        :param metrics: metrics container; its attribute names (via ``vars``) become CSV columns
        :param losses: losses container; its attribute names (via ``vars``) become CSV columns
        :param log_dir: directory holding the history files
        :param train_batches_filename: file name of the per-batch training history
        :param train_epochs_filename: file name of the per-epoch training history
        :param validation_filename: file name of the validation history
        :param verbose: verbosity flag
        :param memorize: whether to memorize batches or epochs
        :param save_signal: flag selecting when saving is triggered
        :param overwatch_metric: metric watched to decide auto-saving; defaults to a new
                                 OverWatchMetric on TOTAL_LOSS with DEEP_SAVE_CONDITION_LESS

        NOTE(review): ``metrics``/``losses`` are annotated as dict but are passed to ``vars()``,
        which requires objects with a ``__dict__`` (e.g. a Namespace) -- confirm with callers.
        """
        self.log_dir = log_dir
        self.verbose = verbose
        self.metrics = metrics
        self.losses = losses
        self.memorize = get_corresponding_flag(
            [DEEP_MEMORIZE_BATCHES, DEEP_MEMORIZE_EPOCHS], info=memorize)
        self.save_signal = save_signal

        # Build the default here instead of in the signature: a default
        # instance would be created once and shared (and mutated through
        # set_value) by every History object.
        if overwatch_metric is None:
            overwatch_metric = OverWatchMetric(
                name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS)
        self.overwatch_metric = overwatch_metric

        # Running metrics
        self.running_total_loss = 0
        self.running_losses = {}
        self.running_metrics = {}

        # A single manager server process is enough to serve all three
        # queues; the queue proxies keep a reference to it.
        manager = multiprocessing.Manager()
        self.train_batches_history = manager.Queue()
        self.train_epochs_history = manager.Queue()
        self.validation_history = manager.Queue()

        # Add headers to history files
        loss_names = list(vars(losses).keys())
        metric_names = list(vars(metrics).keys())
        train_batches_headers = ",".join(
            [WALL_TIME, RELATIVE_TIME, EPOCH, BATCH, TOTAL_LOSS] +
            loss_names + metric_names)
        train_epochs_headers = ",".join(
            [WALL_TIME, RELATIVE_TIME, EPOCH, TOTAL_LOSS] +
            loss_names + metric_names)
        validation_headers = ",".join(
            [WALL_TIME, RELATIVE_TIME, EPOCH, TOTAL_LOSS] +
            loss_names + metric_names)

        # Create the history files
        self.__add_logs("history_train_batches", log_dir, ".csv",
                        train_batches_headers)
        self.__add_logs("history_train_epochs", log_dir, ".csv",
                        train_epochs_headers)
        self.__add_logs("history_validation", log_dir, ".csv",
                        validation_headers)

        self.start_time = 0
        self.paused = False

        # Filepaths
        self.train_batches_filename = train_batches_filename
        self.train_epochs_filename = train_epochs_filename
        self.validation_filename = validation_filename

        # Load histories
        self.__load_histories()

        # Connect to signals
        Thalamus().connect(receiver=self.on_batch_end,
                           event=DEEP_EVENT_ON_BATCH_END,
                           expected_arguments=[
                               "minibatch_index", "num_minibatches",
                               "epoch_index", "total_loss", "result_losses",
                               "result_metrics"
                           ])
        Thalamus().connect(receiver=self.on_epoch_end,
                           event=DEEP_EVENT_ON_EPOCH_END,
                           expected_arguments=[
                               "epoch_index", "num_epochs", "num_minibatches",
                               "total_validation_loss",
                               "result_validation_losses",
                               "result_validation_metrics",
                               "num_minibatches_validation"
                           ])
        Thalamus().connect(receiver=self.on_train_begin,
                           event=DEEP_EVENT_ON_TRAINING_START,
                           expected_arguments=[])
        Thalamus().connect(receiver=self.on_train_end,
                           event=DEEP_EVENT_ON_TRAINING_END,
                           expected_arguments=[])

        Thalamus().connect(receiver=self.on_epoch_start,
                           event=DEEP_EVENT_ON_EPOCH_START,
                           expected_arguments=["epoch_index", "num_epochs"])
        Thalamus().connect(receiver=self.send_training_loss,
                           event=DEEP_EVENT_REQUEST_TRAINING_LOSS,
                           expected_arguments=[])
Пример #27
0
    def __train(self, first_training: bool = True) -> None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Loop over the dataset to train the network.

        For each epoch, from ``initial_epoch + 1`` to ``num_epochs``
        inclusive (the current epoch is stored in ``self.epoch``), it:

        - emits DEEP_EVENT_ON_EPOCH_START;
        - shuffles the dataset when a shuffle method is set;
        - runs one optimisation step per minibatch, emitting
          DEEP_EVENT_ON_BATCH_END after each;
        - resets the dataset, evaluates the model and emits
          DEEP_EVENT_ON_EPOCH_END.

        DEEP_EVENT_ON_TRAINING_END is emitted once all epochs are done.

        PARAMETERS:
        -----------

        :param first_training (bool): True for the initial training run, False when more epochs were requested afterwards; only the initial run emits DEEP_EVENT_ON_TRAINING_START

        RETURN:
        -------

        :return: None
        """
        # Only the very first training run announces the start of training
        if first_training is True:
            Thalamus().add_signal(
                signal=Signal(event=DEEP_EVENT_ON_TRAINING_START, args={}))

        # The loop variable is stored on the instance (self.epoch)
        for self.epoch in range(self.initial_epoch + 1, self.num_epochs + 1):

            Thalamus().add_signal(
                signal=Signal(event=DEEP_EVENT_ON_EPOCH_START,
                              args={
                                  "epoch_index": self.epoch,
                                  "num_epochs": self.num_epochs
                              }))

            # Shuffle the data if required
            if self.shuffle_method is not None:
                self.dataset.shuffle(self.shuffle_method)

            # Put model into train mode for the start of the epoch
            self.model.train()

            for minibatch_index, minibatch in enumerate(self.dataloader, 0):

                # Clean the given data
                inputs, labels, additional_data = self.clean_single_element_list(
                    minibatch)

                # zero the parameter gradients
                self.optimizer.zero_grad()

                # Set the data to the corresponding device
                inputs = self.to_device(inputs, self.model.device)
                labels = self.to_device(labels, self.model.device)
                additional_data = self.to_device(additional_data,
                                                 self.model.device)

                # Infer the output of the batch
                # NOTE(review): if Notification(DEEP_NOTIF_FATAL, ...) does not
                # stop execution, `outputs` is unbound below and a NameError
                # follows -- confirm Notification's fatal behaviour.
                try:
                    outputs = self.model(*inputs)
                except RuntimeError as e:
                    Notification(DEEP_NOTIF_FATAL,
                                 "RuntimeError : %s" % str(e))
                except TypeError as e:
                    Notification(DEEP_NOTIF_FATAL, "TypeError : %s" % str(e))

                # Compute losses and metrics
                result_losses = self.compute_metrics(self.losses, inputs,
                                                     outputs, labels,
                                                     additional_data)
                result_metrics = self.compute_metrics(self.metrics, inputs,
                                                      outputs, labels,
                                                      additional_data)

                # Add weights to losses
                result_losses = dict_utils.apply_weight(
                    result_losses, vars(self.losses))

                # Sum all the result of the losses
                total_loss = sum_dict(result_losses)

                # Accumulates the gradient (by addition) for each parameter
                total_loss.backward()

                # Performs a parameter update based on the current gradient (stored in .grad attribute of a parameter)
                # and the update rule
                self.optimizer.step()

                # Detach the tensors from the network
                outputs, total_loss, result_losses, result_metrics = self.detach(
                    outputs=outputs,
                    total_loss=total_loss,
                    result_losses=result_losses,
                    result_metrics=result_metrics)

                # Send signal batch end (minibatch_index is reported 1-based)
                Thalamus().add_signal(
                    Signal(event=DEEP_EVENT_ON_BATCH_END,
                           args={
                               "minibatch_index": minibatch_index + 1,
                               "num_minibatches": self.num_minibatches,
                               "epoch_index": self.epoch,
                               "total_loss": total_loss.item(),
                               "result_losses": result_losses,
                               "result_metrics": result_metrics
                           }))

            # Reset the dataset (transforms cache)
            self.dataset.reset()

            # Evaluate the model
            self.validation_loss, result_validation_losses, result_validation_metrics = self.__evaluate_epoch(
            )

            # Without a tester there is no validation minibatch count
            if self.tester is not None:
                num_minibatches_validation = self.tester.get_num_minibatches()
            else:
                num_minibatches_validation = None

            # Send signal epoch end
            # The model is passed here as a weak reference, whereas the
            # training-end signal below passes the model object itself.
            Thalamus().add_signal(
                Signal(event=DEEP_EVENT_ON_EPOCH_END,
                       args={
                           "epoch_index":
                           self.epoch,
                           "num_epochs":
                           self.num_epochs,
                           "model":
                           weakref.ref(self.model),
                           "num_minibatches":
                           self.num_minibatches,
                           "total_validation_loss":
                           self.validation_loss,
                           "result_validation_losses":
                           result_validation_losses,
                           "result_validation_metrics":
                           result_validation_metrics,
                           "num_minibatches_validation":
                           num_minibatches_validation,
                       }))

        # Send signal end training
        Thalamus().add_signal(
            Signal(event=DEEP_EVENT_ON_TRAINING_END,
                   args={"model": self.model}))
Пример #28
0
 def send_training_loss(self):
     """
     Broadcast the current training loss.

     Emits a DEEP_EVENT_SEND_TRAINING_LOSS signal whose "training_loss"
     argument is the current value of ``self.running_total_loss``.
     Presumably this is the loss accumulated so far in the epoch, not an
     average -- TODO confirm.
     """
     thalamus = Thalamus()
     payload = {"training_loss": self.running_total_loss}
     thalamus.add_signal(Signal(event=DEEP_EVENT_SEND_TRAINING_LOSS, args=payload))
Пример #29
0
    def __train(self, first_training=True)->None:
        """
        AUTHORS:
        --------

        :author: Alix Leroy

        DESCRIPTION:
        ------------

        Loop over the dataset to train the network.

        For each epoch, from ``initial_epoch + 1`` to ``num_epochs``
        inclusive, it:

        - emits DEEP_EVENT_ON_EPOCH_START;
        - runs one optimisation step per minibatch, emitting
          DEEP_EVENT_ON_BATCH_END after each;
        - shuffles the dataset if required, then resets it;
        - evaluates the model and emits DEEP_EVENT_ON_EPOCH_END.

        DEEP_EVENT_ON_TRAINING_END is emitted once all epochs are done,
        after which the callbacks are paused.

        When no validation was run (the evaluation returns no validation loss
        and/or there is no tester), the epoch-end signal carries None for the
        validation loss and for the validation minibatch count.

        PARAMETERS:
        -----------

        :param first_training->bool: True for the initial training run (emits DEEP_EVENT_ON_TRAINING_START); False when more epochs were requested afterwards (unpauses the callbacks instead)

        RETURN:
        -------

        :return: None
        """

        if first_training is True:
            Thalamus().add_signal(signal=Signal(event=DEEP_EVENT_ON_TRAINING_START, args={}))
        else:
            self.callbacks.unpause()

        for epoch in range(self.initial_epoch+1, self.num_epochs+1):  # loop over the dataset multiple times

            Thalamus().add_signal(signal=Signal(event=DEEP_EVENT_ON_EPOCH_START, args={"epoch_index": epoch,
                                                                                       "num_epochs": self.num_epochs}))

            for minibatch_index, minibatch in enumerate(self.dataloader, 0):

                # Clean the given data
                inputs, labels, additional_data = self.clean_single_element_list(minibatch)

                # zero the parameter gradients
                self.optimizer.zero_grad()

                # Infer the output of the batch
                outputs = self.model(*inputs)

                # Compute losses and metrics
                result_losses = self.compute_metrics(self.losses, inputs, outputs, labels, additional_data)
                result_metrics = self.compute_metrics(self.metrics, inputs, outputs, labels, additional_data)

                # Add weights to losses
                result_losses = apply_weight(result_losses, self.losses)

                # Sum all the result of the losses
                total_loss = sum_dict(result_losses)

                # Accumulates the gradient (by addition) for each parameter
                total_loss.backward()

                # Performs a parameter update based on the current gradient (stored in .grad attribute of a parameter)
                # and the update rule
                self.optimizer.step()

                # Detach the tensors from the network
                outputs, total_loss, result_losses, result_metrics = self.detach(outputs=outputs,
                                                                                 total_loss=total_loss,
                                                                                 result_losses=result_losses,
                                                                                 result_metrics=result_metrics)

                # Send signal batch end (minibatch_index is reported 1-based)
                Thalamus().add_signal(Signal(event=DEEP_EVENT_ON_BATCH_END,
                                             args={"minibatch_index": minibatch_index+1,
                                                   "num_minibatches": self.num_minibatches,
                                                   "epoch_index": epoch,
                                                   "total_loss": total_loss.item(),
                                                   "result_losses": result_losses,
                                                   "result_metrics": result_metrics
                                                   }))

            # Shuffle the data if required
            if self.shuffle is not None:
                self.dataset.shuffle(self.shuffle)

            # Reset the dataset (transforms cache)
            self.dataset.reset()

            # Evaluate the model
            total_validation_loss, result_validation_losses, result_validation_metrics = self.__evaluate_epoch()

            # Without validation there is no loss tensor to convert and no
            # tester to query; forward None instead of crashing (listeners
            # such as the history already handle a None validation loss).
            if total_validation_loss is not None:
                total_validation_loss = total_validation_loss.item()
            if self.tester is not None:
                num_minibatches_validation = self.tester.get_num_minibatches()
            else:
                num_minibatches_validation = None

            # Send signal epoch end
            Thalamus().add_signal(Signal(event=DEEP_EVENT_ON_EPOCH_END,
                                         args={"epoch_index": epoch,
                                               "num_epochs": self.num_epochs,
                                               "model": self.model,
                                               "num_minibatches": self.num_minibatches,
                                               "total_validation_loss": total_validation_loss,
                                               "result_validation_losses": result_validation_losses,
                                               "result_validation_metrics": result_validation_metrics,
                                               "num_minibatches_validation": num_minibatches_validation
                                               }))

        # Send signal end training
        Thalamus().add_signal(Signal(event=DEEP_EVENT_ON_TRAINING_END,
                                     args={"model": self.model}))

        # Pause callbacks which compute time
        self.callbacks.pause()