def __init__(
        self,
        model=None,
        optimizer=None,
        overwatch: OverWatch = None,
        save_directory: str = "weights",
        save_signal: Flag = DEEP_EVENT_EPOCH_END,
        method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
        overwrite: bool = False,
):
    """
    DESCRIPTION:
    ------------

    Initialize the saver: store what has to be saved and how, resolve the save signal and the
    save format into flags, pick the file extension matching the format and subscribe
    save_model to DEEP_EVENT_SAVE_MODEL

    PARAMETERS:
    -----------

    :param model: The model, stored as self.model
    :param optimizer: The optimizer, stored as self.optimizer
    :param overwatch (OverWatch): Stored as self.overwatch
    :param save_directory (str): Stored as self.directory
    :param save_signal (Flag): Resolved against DEEP_LIST_SAVE_SIGNAL
    :param method (Flag): The save format, resolved against DEEP_LIST_SAVE_FORMATS
    :param overwrite (bool): Stored as self.overwrite

    RETURN:
    -------

    :return: None
    """
    self.model, self.optimizer = model, optimizer
    self.directory = save_directory
    self.overwatch = overwatch
    self.overwrite = overwrite

    self.save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL, save_signal)
    save_format = get_corresponding_flag(DEEP_LIST_SAVE_FORMATS, method)
    self.method = save_format  # Can be onnx or pt

    # Map the save format onto its file extension (first match wins)
    # NOTE(review): self.extension stays unset if the format matches neither entry
    for format_flag, extension in ((DEEP_SAVE_FORMAT_PYTORCH, DEEP_EXT_PYTORCH),
                                   (DEEP_SAVE_FORMAT_ONNX, DEEP_EXT_ONNX)):
        if format_flag.corresponds(save_format):
            self.extension = extension
            break

    # Save whenever a save is explicitly requested
    Thalamus().connect(receiver=self.save_model, event=DEEP_EVENT_SAVE_MODEL, expected_arguments=[])
def __init__(
        self,
        metric: str = DEEP_LOG_TOTAL_LOSS,
        condition: Union[Flag, None] = DEEP_SAVE_CONDITION_LESS,
        dataset: Union[Flag, None] = DEEP_DATASET_VAL
):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Initialize the OverWatchMetric instance

    PARAMETERS:
    -----------

    :param metric (str): The name of the metric to over watch
    :param condition (Union[Flag, None]): The save condition, resolved against
        DEEP_LIST_SAVE_CONDITIONS; None falls back to DEEP_SAVE_CONDITION_LESS
    :param dataset (Union[Flag, None]): The dataset the metric is read from (training or
        validation); None falls back to DEEP_DATASET_VAL

    RETURN:
    -------

    :return: None
    """
    self.metric = metric
    self.dataset = DEEP_DATASET_VAL if dataset is None \
        else get_corresponding_flag([DEEP_DATASET_TRAIN, DEEP_DATASET_VAL], dataset)
    self.current_best = None
    # The signature allows None here too; default it the same way as `dataset` instead of
    # asking get_corresponding_flag to look up None
    self._condition = DEEP_SAVE_CONDITION_LESS if condition is None \
        else get_corresponding_flag(DEEP_LIST_SAVE_CONDITIONS, condition)
    self._is_better = None
    self.set_is_better()
def __init__(self,
             name: str = "no_model_name",
             save_directory: str = "weights",
             save_signal: Flag = DEEP_EVENT_ON_EPOCH_END,
             method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
             overwrite: bool = False):
    """
    DESCRIPTION:
    ------------

    Initialize the saver: store the save settings, resolve the save signal and format into
    flags, pick the matching file extension, make sure the save directory exists and
    connect the saver to the events it reacts to

    PARAMETERS:
    -----------

    :param name (str): The model name, stored as self.name
    :param save_directory (str): Directory the weights are written to (created if missing)
    :param save_signal (Flag): Resolved against DEEP_LIST_SAVE_SIGNAL
    :param method (Flag): The save format, resolved against DEEP_LIST_SAVE_FORMATS
    :param overwrite (bool): Stored as self.overwrite

    RETURN:
    -------

    :return: None
    """
    self.name = name
    self.directory = save_directory
    self.save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL, save_signal)
    self.method = get_corresponding_flag(DEEP_LIST_SAVE_FORMATS, method)  # Can be onnx or pt
    self.best_overwatch_metric = None
    self.training_loss = None
    self.model = None
    self.optimizer = None
    self.epoch_index = -1
    self.batch_index = -1
    self.validation_loss = None
    self.overwrite = overwrite
    self.inp = None

    # Set the extension
    # NOTE(review): self.extension stays unset if the format is neither PyTorch nor ONNX
    if DEEP_SAVE_FORMAT_PYTORCH.corresponds(self.method):
        self.extension = DEEP_EXT_PYTORCH
    elif DEEP_SAVE_FORMAT_ONNX.corresponds(self.method):
        self.extension = DEEP_EXT_ONNX

    # exist_ok=True makes this a no-op when the directory already exists, while a regular
    # file at that path raises FileExistsError instead of being silently accepted as the
    # save directory
    os.makedirs(self.directory, exist_ok=True)

    # Connect the save to the computation of the overwatched metric
    Thalamus().connect(receiver=self.on_overwatch_metric_computed,
                       event=DEEP_EVENT_OVERWATCH_METRIC_COMPUTED,
                       expected_arguments=["current_overwatch_metric"])
    Thalamus().connect(receiver=self.on_training_end,
                       event=DEEP_EVENT_ON_TRAINING_END,
                       expected_arguments=[])
    Thalamus().connect(receiver=self.save_model,
                       event=DEEP_EVENT_SAVE_MODEL,
                       expected_arguments=[])
    Thalamus().connect(receiver=self.set_training_loss,
                       event=DEEP_EVENT_SEND_TRAINING_LOSS,
                       expected_arguments=["training_loss"])
    Thalamus().connect(receiver=self.set_save_params,
                       event=DEEP_EVENT_SEND_SAVE_PARAMS_FROM_TRAINER,
                       expected_arguments=[
                           "model", "optimizer", "epoch_index", "validation_loss", "inp"
                       ])
def __check_load_as(self, load_as: Union[str, None, Flag]) -> Flag:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Validate the load_as argument and translate it into its DEEP_LOAD_AS flag

    PARAMETERS:
    -----------

    :param load_as(Union[str, None, Flag]): The load_as argument given in the config file

    RETURN:
    -------

    :return (Flag): The matching DEEP_LOAD_AS flag, or None when load_as is None
    """
    if load_as is not None:
        return get_corresponding_flag(flag_list=DEEP_LIST_LOAD_AS, info=load_as)
    return None
def __check_data_type(self, data_type: Union[str, int, Flag, None]):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Resolve the data type of the entry.
    A user-given data type is translated into its DEEP_DTYPE flag; when none is given the
    type is estimated from the first item (estimation can fail on complex types)

    PARAMETERS:
    -----------

    :param data_type (Union[str, int, Flag, None]): The raw data type given by the user

    RETURN:
    -------

    :return data_type(Flag): The data type of the entry
    """
    if data_type is not None:
        return get_corresponding_flag(flag_list=DEEP_LIST_DTYPE, info=data_type)

    # Nothing given: inspect the first item and guess its type
    first_instance, _, _ = self.__getitem__(index=0)
    return self.__estimate_data_type(first_instance)
def __init__(self,
             dataset: Dataset,
             model,
             transform_manager,
             losses: Losses,
             metrics: Union[Metrics, None] = None,
             batch_size: int = 32,
             num_workers: int = 1,
             shuffle: Flag = DEEP_SHUFFLE_NONE,
             name: str = "Inferer"):
    """
    DESCRIPTION:
    ------------

    Store the inference components and build the DataLoader over the dataset

    PARAMETERS:
    -----------

    :param dataset (Dataset): The dataset to iterate over
    :param model: The model, stored as self.model
    :param transform_manager: Stored as self.transform_manager
    :param losses (Losses): The losses, stored as self.losses
    :param metrics (Union[Metrics, None]): The metrics; None builds an empty Metrics()
    :param batch_size (int): Number of instances per batch
    :param num_workers (int): Number of DataLoader workers
    :param shuffle (Flag): Resolved against DEEP_LIST_SHUFFLE, DEEP_SHUFFLE_NONE if unknown
    :param name (str): Name of the instance

    RETURN:
    -------

    :return: None
    """
    self.dataset = dataset
    self.model = model
    self.transform_manager = transform_manager
    self.losses = losses
    if metrics is None:
        metrics = Metrics()
    self.metrics = metrics
    self.batch_size = batch_size
    self.num_workers = num_workers
    self.name = name
    self.shuffle = get_corresponding_flag(
        DEEP_LIST_SHUFFLE, shuffle, fatal=False, default=DEEP_SHUFFLE_NONE
    )
    # The DataLoader itself never shuffles; self.shuffle only records the requested flag
    self.dataloader = DataLoader(
        dataset=self.dataset,
        batch_size=self.batch_size,
        shuffle=False,
        num_workers=self.num_workers,
    )
def __check_convert_to(
        convert_to: Union[str, None, Flag]) -> Optional[Flag]:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Validate the convert_to argument and translate it into its DEEP_DTYPE flag

    PARAMETERS:
    -----------

    :param convert_to(Union[str, None, Flag]): The convert_to argument given in the config file

    RETURN:
    -------

    :return (Optional[Flag]): The matching DEEP_DTYPE flag, or None when convert_to is None
    """
    return (
        None
        if convert_to is None
        else get_corresponding_flag(flag_list=DEEP_LIST_DTYPE, info=convert_to)
    )
def __convert_pointer_type(self, pointer_type: Union[str, int, Flag]) -> Flag:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Convert the pointer type to the actual entry flag.
    An unknown pointer type is reported with a fatal notification naming this transformer

    PARAMETERS:
    -----------

    :param pointer_type (Union[str, int, Flag]): The pointer type the user wants

    RETURN:
    -------

    :return flag(Flag): The corresponding flag of entry type (None only if the fatal
        notification returns)
    """
    # fatal=False lets an unknown value reach the transformer-specific notification below.
    # With fatal=True the failure would be handled inside get_corresponding_flag, which
    # leaves the `flag is None` branch unreachable.
    flag = get_corresponding_flag(flag_list=DEEP_LIST_POINTER_ENTRY,
                                  info=pointer_type,
                                  fatal=False)
    if flag is None:
        Notification(
            DEEP_NOTIF_FATAL,
            "The type of the following transformer's pointer does not exist :' %s'. "
            "Please check the documentation." % str(self.name))
    return flag
def set_cv_library(self, cv_library: Flag) -> None:
    """
    AUTHORS:
    --------

    :author: Samuel Westlake
    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Set self.cv_library to the flag matching the given value and import the corresponding
    cv library

    PARAMETERS:
    -----------

    :param cv_library: (Flag): The computer vision library selected (any value
        get_corresponding_flag can match against DEEP_LIST_CV_LIB)

    RETURN:
    -------

    None
    """
    self.cv_library = get_corresponding_flag(flag_list=DEEP_LIST_CV_LIB, info=cv_library)
    # Import using the resolved flag, so the importer and self.cv_library always agree even
    # when the caller passed a raw (non-Flag) value
    self.__import_cv_library(cv_library=self.cv_library)
def __check_args(self) -> dict:
    """
    DESCRIPTION:
    ------------

    Inspect the signature of the loss method and map each argument name onto an entry flag.
    Custom losses must name their arguments after DEEP_LIST_ENTRY flags; other losses may
    use the usual PyTorch names (input/x/inputs for the model output, y/label/target... for
    the label). Unexpected names raise a fatal notification

    RETURN:
    -------

    :return (dict): argument name -> entry flag
    """
    if inspect.isfunction(self.method):
        arg_names = inspect.getfullargspec(self.method).args
    else:
        arg_names = inspect.getfullargspec(self.method.forward).args
        arg_names.remove("self")

    output_aliases = ("input", "x", "inputs")
    label_aliases = ("y", "y_hat", "label", "labels", "target", "targets")

    flag_to_arg = {}
    for name in arg_names:
        if self.is_custom:
            flag = get_corresponding_flag(DEEP_LIST_ENTRY, name, fatal=False)
        elif name in output_aliases:
            # PyTorch losses call the prediction "input"
            flag = DEEP_ENTRY_OUTPUT
        elif name in label_aliases:
            flag = DEEP_ENTRY_LABEL
        else:
            flag = None

        if flag is None:
            Notification(DEEP_NOTIF_FATAL, DEEP_MSG_LOSS_UNEXPECTED_ARG % name)
        else:
            flag_to_arg[flag] = name

    self.check_essential_args(flag_to_arg)
    return {name: flag for flag, name in flag_to_arg.items()}
def forward(self, flag, outputs, labels, inputs=None, additional_data=None, model=None):
    """
    DESCRIPTION:
    ------------

    Compute every metric on the given batch, record the values for the given dataset and
    return them

    PARAMETERS:
    -----------

    :param flag: The dataset the batch comes from, resolved against DEEP_LIST_DATASET
    :param outputs: The model outputs
    :param labels: The labels
    :param inputs: The inputs (optional)
    :param additional_data: Additional data (optional)
    :param model: The model (optional)

    RETURN:
    -------

    :return (dict): metric name -> value (tensors are converted with .item())
    """
    # Strict lookup: flag.name is dereferenced below, so an unrecognised dataset flag must
    # be reported by get_corresponding_flag instead of surfacing as an AttributeError on None
    flag = get_corresponding_flag(DEEP_LIST_DATASET, flag)
    metrics = {}
    for metric_name in self.names:
        metric_value = vars(self)[metric_name].forward(
            outputs=outputs,
            labels=labels,
            inputs=inputs,
            additional_data=additional_data,
            model=model)
        if isinstance(metric_value, dict):
            # A metric returning a dict contributes several named values at once; each new
            # name is registered as an alias of the metric that produced it
            metrics.update(metric_value)
            for k in metric_value.keys():
                if k not in vars(self):
                    vars(self)[k] = vars(self)[metric_name]
        elif isinstance(metric_value, torch.Tensor):
            metrics[metric_name] = metric_value.item()
        else:
            metrics[metric_name] = metric_value
    self.update_values(self.values[flag.name.lower()], metrics)
    return metrics
def __init__(
        self,
        name: str,
        type: Flag,
        entries: List[Namespace],
        num_instances: int,
        transform_manager: Optional[TransformManager],
        use_raw_data: bool = True,
):
    """
    DESCRIPTION:
    ------------

    Initialize the dataset: build its entries and pipeline entries from the entry
    configurations, then compute its length and default item order

    PARAMETERS:
    -----------

    :param name (str): Name of the dataset
    :param type (Flag): The dataset type, resolved against DEEP_LIST_DATASET
    :param entries (List[Namespace]): The entry configurations
    :param num_instances (int): Desired length, passed to __compute_length
    :param transform_manager (Optional[TransformManager]): Stored as self.transform_manager
    :param use_raw_data (bool): Whether raw data is used or only transformed data

    RETURN:
    -------

    :return: None
    """
    entry_dicts = list_namespace2list_dict(entries)

    self.name = name
    self.type = get_corresponding_flag(DEEP_LIST_DATASET, type)

    # Entry instances
    self.entries = []
    self.__generate_entries(entries=entry_dicts)

    # PipelineEntry instances
    self.pipeline_entries = []
    self.__generate_pipeline_entries(entries=entry_dicts)

    self.number_raw_instances = self.__calculate_number_raw_instances()
    self.length = self.__compute_length(
        desired_length=num_instances,
        num_raw_instances=self.number_raw_instances,
    )
    self.item_order = np.arange(self.length)  # indices of the items, in order
    self.use_raw_data = use_raw_data
    self.transform_manager = transform_manager
def reduce(self, flag):
    """
    DESCRIPTION:
    ------------

    Average the recorded values of each loss for the given dataset

    PARAMETERS:
    -----------

    :param flag: The dataset, resolved against DEEP_LIST_DATASET

    RETURN:
    -------

    :return (tuple): (sum of the averaged losses, dict of loss name -> averaged value)
    """
    # Strict lookup: flag.name is dereferenced below, so an unrecognised dataset flag must
    # be reported by get_corresponding_flag instead of surfacing as an AttributeError on None
    flag = get_corresponding_flag(DEEP_LIST_DATASET, flag)
    losses = {
        loss_name: (sum(values) / len(values)).item()
        for loss_name, values in self.values[flag.name.lower()].items()
    }
    loss = sum(losses.values())
    return loss, losses
def __init__(self,
             model=None,
             optimizer=None,
             overwatch: OverWatch = None,
             save_signal: Flag = DEEP_SAVE_SIGNAL_AUTO,
             method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
             overwrite: bool = False,
             enable_train_batches: bool = True,
             enable_train_epochs: bool = True,
             enable_validation: bool = True,
             history_directory: str = "history",
             weights_directory: str = "weights"):
    """
    DESCRIPTION:
    ------------

    Initialize the Memory: build its History and Saver and connect the event handlers

    PARAMETERS:
    -----------

    :param model: The model, passed to the Saver
    :param optimizer: The optimizer, passed to the Saver
    :param overwatch (OverWatch): Passed to the Saver; None builds a fresh OverWatch()
    :param save_signal (Flag): Resolved against DEEP_LIST_SAVE_SIGNAL, shared by History and Saver
    :param method (Flag): The save format, passed to the Saver
    :param overwrite (bool): Passed to the Saver
    :param enable_train_batches (bool): Passed to the History
    :param enable_train_epochs (bool): Passed to the History
    :param enable_validation (bool): Passed to the History
    :param history_directory (str): Log directory of the History
    :param weights_directory (str): Save directory of the Saver

    RETURN:
    -------

    :return: None
    """
    # A default argument value is evaluated once, when the def runs, so an OverWatch()
    # placed in the signature would be one instance (and whatever state it tracks) shared
    # by every Memory. Build it per instance instead.
    if overwatch is None:
        overwatch = OverWatch()

    self.name = "Memory"
    save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL, save_signal)
    self.history = History(log_dir=history_directory,
                           save_signal=save_signal,
                           enable_train_batches=enable_train_batches,
                           enable_train_epochs=enable_train_epochs,
                           enable_validation=enable_validation)
    self.saver = Saver(model=model,
                       optimizer=optimizer,
                       overwatch=overwatch,
                       save_directory=weights_directory,
                       save_signal=save_signal,
                       method=method,
                       overwrite=overwrite)
    self._model = model
    self._optimizer = optimizer
    self._overwatch = overwatch

    # Connect to signals
    Thalamus().connect(receiver=self.on_batch_end,
                       event=DEEP_EVENT_BATCH_END,
                       expected_arguments=[
                           "batch_index", "num_batches", "epoch_index", "loss",
                           "losses", "metrics"
                       ])
    Thalamus().connect(
        receiver=self.on_epoch_end,
        event=DEEP_EVENT_EPOCH_END,
        expected_arguments=["epoch_index", "loss", "losses", "metrics"])
    Thalamus().connect(
        receiver=self.on_validation_end,
        event=DEEP_EVENT_VALIDATION_END,
        expected_arguments=["epoch_index", "loss", "losses", "metrics"])
    Thalamus().connect(receiver=self.on_train_start,
                       event=DEEP_EVENT_TRAINING_START,
                       expected_arguments=[])
    Thalamus().connect(receiver=self.on_train_end,
                       event=DEEP_EVENT_TRAINING_END,
                       expected_arguments=[])
    Thalamus().connect(receiver=self.send_training_loss,
                       event=DEEP_EVENT_REQUEST_TRAINING_LOSS,
                       expected_arguments=[])
def __init__(self,
             index: int,
             name: str,
             etype: Flag,
             dataset: weakref,
             load_as: str,
             enable_cache: bool = False,
             cv_library: Union[str, None, Flag] = DEEP_LIB_OPENCV):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Initialize an entry for the Dataset and create its Loader

    PARAMETERS:
    -----------

    :param index (int): ID of the entry
    :param name (str): Name of the entry
    :param etype (Flag): Entry type, resolved against DEEP_LIST_ENTRY
    :param dataset(weakref): Weak reference to the dataset
    :param load_as (str): Passed to the Loader
    :param enable_cache (bool): Whether pointers use a cache memory
    :param cv_library (Union[str, None, Flag]): Passed to the Loader

    RETURN:
    -------

    :return: None
    """
    self.index = index
    self.name = name
    self.etype = get_corresponding_flag(DEEP_LIST_ENTRY, etype)
    self.dataset = dataset  # weak reference, so the entry does not keep the dataset alive

    # The loader only holds a weak reference back to this entry
    self.loader = Loader(
        data_entry=weakref.ref(self),
        load_as=load_as,
        cv_library=cv_library,
    )

    self.sources = []
    self.enable_cache = enable_cache
    # Cache memory for pointers exists only when caching is enabled
    self.cache_memory = [] if self.enable_cache is True else None
    self.num_instances = None
def reset(self, flag=None):
    """
    DESCRIPTION:
    ------------

    Clear the recorded values of one dataset, or of every dataset when flag does not match
    any DEEP_LIST_DATASET flag (including flag=None). Then call reset() on every loss or
    metric method that has one, notifying each reset

    PARAMETERS:
    -----------

    :param flag: The dataset to clear, or None for all of them

    RETURN:
    -------

    :return: None
    """
    flag = get_corresponding_flag(DEEP_LIST_DATASET, flag, fatal=False)
    if flag is not None:
        self.values[flag.name.lower()] = {}
    else:
        self.values = {dataset_flag.name.lower(): {} for dataset_flag in DEEP_LIST_DATASET}

    # Methods without a reset() are skipped silently (AttributeError is suppressed)
    for name in self.names:
        with contextlib.suppress(AttributeError):
            vars(self)[name].method.reset()
            kind = "Loss" if self.__class__ is Losses else "Metric"
            Notification(DEEP_NOTIF_INFO, "Reset %s : %s" % (kind, name))
def reduce(self, flag):
    """
    DESCRIPTION:
    ------------

    Reduce the recorded values of each metric for the given dataset with the metric's own
    reduce_method, after dropping values equal to the metric's ignore_value (if set)

    PARAMETERS:
    -----------

    :param flag: The dataset, resolved against DEEP_LIST_DATASET

    RETURN:
    -------

    :return (dict): metric name -> reduced value (float("inf") when reduction raises
        ZeroDivisionError)
    """
    # Strict lookup: flag.name is dereferenced below, so an unrecognised dataset flag must
    # be reported by get_corresponding_flag instead of surfacing as an AttributeError on None
    flag = get_corresponding_flag(DEEP_LIST_DATASET, flag)
    reduced_metrics = {}
    for metric_name, values in self.values[flag.name.lower()].items():
        metric = self.__dict__[metric_name]
        ignore_value = metric.ignore_value
        if ignore_value is not None:
            values = [value for value in values if value != ignore_value]
        try:
            reduced_metrics[metric_name] = metric.reduce_method(values)
        except ZeroDivisionError:
            # presumably an average over an empty list (e.g. every value was ignored)
            reduced_metrics[metric_name] = float("inf")
    return reduced_metrics
def __init__(self,
             model: Module,
             dataset: Dataset,
             metrics: dict,
             losses: dict,
             batch_size: int = 4,
             num_workers: int = 4,
             verbose: Flag = DEEP_VERBOSE_BATCH):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Initialize a GenericEvaluator instance

    PARAMETERS:
    -----------

    :param model->torch.nn.Module: The model to infer
    :param dataset->Dataset: A dataset
    :param metrics->dict: The metrics, stored as self.metrics
    :param losses->dict: The losses, stored as self.losses
    :param batch_size->int: The number of instances per batch
    :param num_workers->int: The number of processes / threads used for data loading
    :param verbose->Flag: How verbose the class is, resolved against DEEP_LIST_VERBOSE

    RETURN:
    -------

    :return: None
    """
    # NOTE(review): model, dataset, batch_size and num_workers are accepted but not stored
    # here; confirm whether another initializer is expected to handle them.
    # Keep the resolved flag; it was previously overwritten by the raw argument
    self.verbose = get_corresponding_flag(DEEP_LIST_VERBOSE, verbose)
    self.metrics = metrics
    self.losses = losses
def __init__(self,
             log_dir: str = "history",
             train_batches_filename: str = "history_train_batches.csv",
             train_epochs_filename: str = "history_train_epochs.csv",
             validation_filename: str = "history_validation.csv",
             save_signal: Flag = DEEP_SAVE_SIGNAL_END_EPOCH,
             write_interval: int = 10,
             enable_train_batches: bool = True,
             enable_train_epochs: bool = True,
             enable_validation: bool = True,
             overwrite: bool = None):
    """
    DESCRIPTION:
    ------------

    Initialize the History: compute the path, enabled state and CSV header of each history
    log (training batches, training epochs, validation), then initialize the files

    PARAMETERS:
    -----------

    :param log_dir (str): Directory of the history files
    :param train_batches_filename (str): File name of the training-batch history
    :param train_epochs_filename (str): File name of the training-epoch history
    :param validation_filename (str): File name of the validation history
    :param save_signal (Flag): Resolved against DEEP_LIST_SAVE_SIGNAL
    :param write_interval (int): Stored as self.write_interval
    :param enable_train_batches (bool): Whether the training-batch history is enabled
    :param enable_train_epochs (bool): Whether the training-epoch history is enabled
    :param enable_validation (bool): Whether the validation history is enabled
    :param overwrite (bool): Stored as self.overwrite

    RETURN:
    -------

    :return: None
    """
    self.log_dir = log_dir
    self.save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL, save_signal)
    self.write_interval = write_interval
    self.overwrite = overwrite

    # One path and one on/off switch per history log, in DEEP_LIST_LOG_HISTORY order
    file_names = (train_batches_filename, train_epochs_filename, validation_filename)
    switches = (enable_train_batches, enable_train_epochs, enable_validation)
    self.file_paths = {
        log.var_name: "/".join((log_dir, file_name))
        for log, file_name in zip(DEEP_LIST_LOG_HISTORY, file_names)
    }
    self.enabled = {
        log.var_name: switch
        for log, switch in zip(DEEP_LIST_LOG_HISTORY, switches)
    }

    # Epoch-level logs have no batch column
    batch_columns = [column.name for column in DEEP_LIST_HISTORY_HEADER]
    epoch_columns = [
        column.name for column in DEEP_LIST_HISTORY_HEADER
        if not column.corresponds(DEEP_LOG_BATCH)
    ]
    self.headers = {
        DEEP_LOG_TRAIN_BATCHES.var_name: batch_columns,
        DEEP_LOG_TRAIN_EPOCHS.var_name: epoch_columns,
        DEEP_LOG_VALIDATION.var_name: list(epoch_columns),  # separate list object
    }

    self._training_start = None
    self._batch_data = {}
    self._loss_data = dict.fromkeys((TRAINING, VALIDATION))
    self.init_files()
def __init__(
        self,
        # History
        losses: dict,
        metrics: dict,
        model_name: str = None,
        verbose: Flag = DEEP_VERBOSE_BATCH,
        memorize: Flag = DEEP_MEMORIZE_BATCHES,
        history_directory: str = "history",
        overwatch_metric: OverWatchMetric = None,
        # Saver
        save_signal: Flag = DEEP_SAVE_SIGNAL_AUTO,
        method: Flag = DEEP_SAVE_FORMAT_PYTORCH,
        overwrite: bool = False,
        save_model_directory: str = "weights"):
    """
    DESCRIPTION:
    ------------

    Initialize the Memory: set up its history and its saver

    PARAMETERS:
    -----------

    :param losses (dict): The losses, passed to the history
    :param metrics (dict): The metrics, passed to the history
    :param model_name (str): Name of the model; None generates a random 10-character
        alphanumeric name for this instance
    :param verbose (Flag): Passed to the history
    :param memorize (Flag): Passed to the history
    :param history_directory (str): Log directory of the history
    :param overwatch_metric (OverWatchMetric): Passed to the history; None builds a fresh
        OverWatchMetric(name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS)
    :param save_signal (Flag): Resolved against DEEP_LIST_SAVE_SIGNAL (default
        DEEP_SAVE_SIGNAL_AUTO), shared by history and saver
    :param method (Flag): The save format, passed to the saver
    :param overwrite (bool): Passed to the saver
    :param save_model_directory (str): Save directory of the saver

    RETURN:
    -------

    :return: None
    """
    # Default argument values are evaluated once, when the def runs: a random name placed
    # in the signature would be the same "random" name for every Memory, and an
    # OverWatchMetric placed there would be one instance shared by all of them.
    if model_name is None:
        model_name = generate_random_alphanumeric(size=10)
    if overwatch_metric is None:
        overwatch_metric = OverWatchMetric(name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS)

    save_signal = get_corresponding_flag(DEEP_LIST_SAVE_SIGNAL,
                                         info=save_signal,
                                         default=DEEP_SAVE_SIGNAL_AUTO)

    #
    # HISTORY
    #
    self.__initialize_history(name=model_name,
                              metrics=metrics,
                              losses=losses,
                              log_dir=history_directory,
                              verbose=verbose,
                              memorize=memorize,
                              save_signal=save_signal,
                              overwatch_metric=overwatch_metric)

    #
    # SAVER
    #
    self.__initialize_saver(name=model_name,
                            save_directory=save_model_directory,
                            save_signal=save_signal,
                            method=method,
                            overwrite=overwrite)
def __check_args(self):
    """
    DESCRIPTION:
    ------------

    Inspect the signature of the metric method and map each argument name onto the
    DEEP_LIST_ENTRY flag it corresponds to. Unknown names raise a fatal notification

    RETURN:
    -------

    :return (dict): argument name -> entry flag
    """
    if inspect.isfunction(self.method):
        signature_args = inspect.getfullargspec(self.method).args
    else:
        signature_args = inspect.getfullargspec(self.method.forward).args
        signature_args.remove("self")

    flag_to_arg = {}
    for name in signature_args:
        entry_flag = get_corresponding_flag(DEEP_LIST_ENTRY, name, fatal=False)
        if entry_flag is not None:
            flag_to_arg[entry_flag] = name
            continue
        Notification(DEEP_NOTIF_FATAL, DEEP_MSG_METRIC_UNEXPECTED_ARG % name)

    self.__check_essential_args(flag_to_arg)
    return dict((name, entry_flag) for entry_flag, name in flag_to_arg.items())
def __check_load_as(self, load_as: Union[str, int, Flag, None]) -> Flag:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Resolve the load_as flag.
    A user-given value is translated into its DEEP_LOAD_AS flag. When none is given, the
    first item of the data entry is inspected: if it is already loaded, None is returned;
    otherwise the flag is estimated from it (errors can occur with complex types)

    PARAMETERS:
    -----------

    :param load_as (Union[str, int, Flag, None]): The raw load_as value given by the user

    RETURN:
    -------

    :return load_as(Flag): The load_as flag of the entry, or None
    """
    if load_as is not None:
        return get_corresponding_flag(flag_list=DEEP_LIST_LOAD_AS, info=load_as)

    # NOTE(review): inside a class body `.__get_first_item()` is name-mangled with *this*
    # class's name, so it only resolves if the data entry defines a member under that
    # mangled name; confirm against the entry class.
    first_item, already_loaded, _ = self.data_entry().__get_first_item()
    if already_loaded is True:
        return None
    return self.__estimate_load_as(first_item)
def forward(self, flag, model, outputs, labels, inputs=None, additional_data=None):
    """
    DESCRIPTION:
    ------------

    Compute every loss on the given batch and record the values for the given dataset

    PARAMETERS:
    -----------

    :param flag: The dataset the batch comes from, resolved against DEEP_LIST_DATASET
    :param model: The model
    :param outputs: The model outputs
    :param labels: The labels
    :param inputs: The inputs (optional)
    :param additional_data: Additional data (optional)

    RETURN:
    -------

    :return (tuple): (sum of the loss values as returned by each loss, dict of loss name ->
        value converted with .item())
    """
    # Strict lookup: flag.name is dereferenced below, so an unrecognised dataset flag must
    # be reported by get_corresponding_flag instead of surfacing as an AttributeError on None
    flag = get_corresponding_flag(DEEP_LIST_DATASET, flag)
    losses = {}
    for loss_name in self.names:
        losses[loss_name] = self.__dict__[loss_name].forward(
            model=model,
            outputs=outputs,
            labels=labels,
            inputs=inputs,
            additional_data=additional_data)
    self.update_values(self.values[flag.name.lower()], losses)
    # The total is summed from the raw values; only the per-loss report is converted to
    # Python numbers (a new dict is built, the one given to update_values is untouched)
    loss = sum(losses.values())
    losses = {loss_name: value.item() for loss_name, value in losses.items()}
    return loss, losses
def __check_entry_type(entry_type: Union[str, int, Flag]) -> Flag:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Translate a raw entry type into its DEEP_LIST_ENTRY flag

    PARAMETERS:
    -----------

    :param entry_type (Union[str, int, Flag]): The raw entry type

    RETURN:
    -------

    :return entry_type(Flag): The entry type flag
    """
    entry_flag = get_corresponding_flag(flag_list=DEEP_LIST_ENTRY, info=entry_type)
    return entry_flag
def __init__(self, name: str = TOTAL_LOSS, condition: Union[Flag, int, str, None] = DEEP_SAVE_CONDITION_LESS):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Initialize the OverWatchMetric instance

    PARAMETERS:
    -----------

    :param name (str): The name of the metric to over watch
    :param condition (Union[Flag, int, str, None]): The save condition, resolved against
        DEEP_LIST_SAVE_CONDITIONS; anything unrecognised falls back to DEEP_SAVE_CONDITION_LESS

    RETURN:
    -------

    :return: None
    """
    resolved_condition = get_corresponding_flag(
        flag_list=DEEP_LIST_SAVE_CONDITIONS,
        info=condition,
        fatal=False,
        default=DEEP_SAVE_CONDITION_LESS,
    )
    self.name = name
    self.value = 0.0
    self.condition = resolved_condition
def __init__(self,
             metrics: dict,
             losses: dict,
             log_dir: str = "history",
             train_batches_filename: str = "history_batches_training.csv",
             train_epochs_filename: str = "history_epochs_training.csv",
             validation_filename: str = "history_validation.csv",
             verbose: Flag = DEEP_VERBOSE_BATCH,
             memorize: Flag = DEEP_MEMORIZE_BATCHES,
             save_signal: Flag = DEEP_SAVE_SIGNAL_END_EPOCH,
             overwatch_metric: OverWatchMetric = None):
    """
    DESCRIPTION:
    ------------

    Initialize the History: store the settings, create the history queues and files (with
    their CSV headers), load previous histories and connect the event handlers

    PARAMETERS:
    -----------

    :param metrics: The metrics; their attribute names (read with vars()) become CSV columns
    :param losses: The losses; their attribute names (read with vars()) become CSV columns
    :param log_dir (str): Directory of the history files
    :param train_batches_filename (str): Stored as self.train_batches_filename
    :param train_epochs_filename (str): Stored as self.train_epochs_filename
    :param validation_filename (str): Stored as self.validation_filename
    :param verbose (Flag): Stored as self.verbose
    :param memorize (Flag): DEEP_MEMORIZE_BATCHES or DEEP_MEMORIZE_EPOCHS
    :param save_signal (Flag): Stored as self.save_signal
    :param overwatch_metric (OverWatchMetric): The metric to over watch; None builds a fresh
        OverWatchMetric(name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS)

    RETURN:
    -------

    :return: None
    """
    # A default argument value is evaluated once, when the def runs, so an OverWatchMetric
    # placed in the signature would be one instance shared by every History
    if overwatch_metric is None:
        overwatch_metric = OverWatchMetric(name=TOTAL_LOSS, condition=DEEP_SAVE_CONDITION_LESS)

    self.log_dir = log_dir
    self.verbose = verbose
    self.metrics = metrics
    self.losses = losses
    self.memorize = get_corresponding_flag(
        [DEEP_MEMORIZE_BATCHES, DEEP_MEMORIZE_EPOCHS], info=memorize)
    self.save_signal = save_signal
    self.overwatch_metric = overwatch_metric

    # Running metrics
    self.running_total_loss = 0
    self.running_losses = {}
    self.running_metrics = {}

    # Each multiprocessing.Manager() starts its own server process; one is enough for the
    # three queues (the queue proxies keep the manager alive)
    manager = multiprocessing.Manager()
    self.train_batches_history = manager.Queue()
    self.train_epochs_history = manager.Queue()
    self.validation_history = manager.Queue()

    # Add headers to history files
    extra_columns = list(vars(losses).keys()) + list(vars(metrics).keys())
    train_batches_headers = ",".join(
        [WALL_TIME, RELATIVE_TIME, EPOCH, BATCH, TOTAL_LOSS] + extra_columns)
    train_epochs_headers = ",".join(
        [WALL_TIME, RELATIVE_TIME, EPOCH, TOTAL_LOSS] + extra_columns)
    validation_headers = train_epochs_headers  # same columns as the epoch history

    # Create the history files
    # NOTE(review): these hard-coded base names differ from the *_filename defaults stored
    # below (e.g. "history_train_batches" vs "history_batches_training.csv"); confirm which
    # names __add_logs and __load_histories expect.
    self.__add_logs("history_train_batches", log_dir, ".csv", train_batches_headers)
    self.__add_logs("history_train_epochs", log_dir, ".csv", train_epochs_headers)
    self.__add_logs("history_validation", log_dir, ".csv", validation_headers)

    self.start_time = 0
    self.paused = False

    # Filepaths
    self.train_batches_filename = train_batches_filename
    self.train_epochs_filename = train_epochs_filename
    self.validation_filename = validation_filename

    # Load histories
    self.__load_histories()

    # Connect to signals
    Thalamus().connect(receiver=self.on_batch_end,
                       event=DEEP_EVENT_ON_BATCH_END,
                       expected_arguments=[
                           "minibatch_index", "num_minibatches", "epoch_index",
                           "total_loss", "result_losses", "result_metrics"
                       ])
    Thalamus().connect(receiver=self.on_epoch_end,
                       event=DEEP_EVENT_ON_EPOCH_END,
                       expected_arguments=[
                           "epoch_index", "num_epochs", "num_minibatches",
                           "total_validation_loss", "result_validation_losses",
                           "result_validation_metrics",
                           "num_minibatches_validation"
                       ])
    Thalamus().connect(receiver=self.on_train_begin,
                       event=DEEP_EVENT_ON_TRAINING_START,
                       expected_arguments=[])
    Thalamus().connect(receiver=self.on_train_end,
                       event=DEEP_EVENT_ON_TRAINING_END,
                       expected_arguments=[])
    Thalamus().connect(receiver=self.on_epoch_start,
                       event=DEEP_EVENT_ON_EPOCH_START,
                       expected_arguments=["epoch_index", "num_epochs"])
    Thalamus().connect(receiver=self.send_training_loss,
                       event=DEEP_EVENT_REQUEST_TRAINING_LOSS,
                       expected_arguments=[])
def __create_transformer(self, config_entry):
    """
    CONTRIBUTORS:
    -------------

    Creator : Alix Leroy

    DESCRIPTION:
    ------------

    Create the adequate transformer: a NoTransformer when there is no config, a Pointer
    when the config points to another transformer, otherwise the Sequential / OneOf /
    SomeOf transformer named by the config's "method"

    PARAMETERS:
    -----------

    :param config_entry: The transformer config (or None)

    RETURN:
    -------

    :return transformer: The created transformer (None if creation failed without exiting)
    """
    if config_entry is None:
        return NoTransformer()
    if self.__is_pointer(config_entry) is True:
        # Generic transformer that points to another transformer
        return Pointer(config_entry)

    config = Namespace(config_entry)
    if config.check("method", None) is False:
        Notification(
            DEEP_NOTIF_FATAL,
            "The following transformer does not have any method specified : " +
            str(config_entry))

    flag = get_corresponding_flag(flag_list=DEEP_LIST_TRANSFORMERS,
                                  info=config.method,
                                  fatal=False)

    # Transformer flag -> class; the first matching flag wins
    candidates = (
        (DEEP_TRANSFORMER_SEQUENTIAL, Sequential),
        (DEEP_TRANSFORMER_ONE_OF, OneOf),
        (DEEP_TRANSFORMER_SOME_OF, SomeOf),
    )
    transformer = None
    try:
        for transformer_flag, transformer_class in candidates:
            if transformer_flag.corresponds(flag):
                transformer = transformer_class(**config.get(ignore="method"))
                break
        else:
            Notification(
                DEEP_NOTIF_FATAL,
                "Unknown transformer method specified in %s : %s" % (config_entry, config.method),
                solutions=[
                    "Ensure a valid transformer method is specified in %s" % config_entry
                ])
    except TypeError as e:
        # Typically unexpected / missing keyword arguments in the config
        Notification(
            DEEP_NOTIF_FATAL,
            "TypeError when loading transformer : %s : %s" % (config_entry, e),
            solutions=["Check the syntax of %s" % config_entry])
    return transformer
def reduce(self, new_value):
    """
    DESCRIPTION:
    ------------

    Resolve new_value against DEEP_LIST_REDUCE, store it and refresh the reduce method

    PARAMETERS:
    -----------

    :param new_value: The requested reduce mode

    RETURN:
    -------

    :return: None
    """
    reduce_flag = get_corresponding_flag(DEEP_LIST_REDUCE, new_value)
    self._reduce = reduce_flag
    self.__set_reduce_method()
def __create_transformer(self, config_entry):
    """
    CONTRIBUTORS:
    -------------

    Creator : Alix Leroy

    DESCRIPTION:
    ------------

    Create the adequate transformer: a NoTransformer when there is no config, a Pointer
    when the config points to another transformer, otherwise the Sequential / OneOf /
    SomeOf transformer named by the config's "method"

    PARAMETERS:
    -----------

    :param config_entry: The transformer config (or None)

    RETURN:
    -------

    :return transformer: The created transformer
    """
    transformer = None

    # NONE
    if config_entry is None:
        transformer = NoTransformer()

    # POINTER
    elif self.__is_pointer(config_entry) is True:
        transformer = Pointer(config_entry)  # Generic Transformer as a pointer

    # TRANSFORMER
    else:
        config = Namespace(config_entry)

        # Check if a method is given by the user
        if config.check("method", None) is False:
            Notification(
                DEEP_NOTIF_FATAL,
                "The following transformer does not have any method specified : " +
                str(config_entry))

        # Keep the method name: it is removed from the config below (the remaining keys
        # become constructor arguments) but is still needed for the error message
        method = config.method
        flag = get_corresponding_flag(flag_list=DEEP_LIST_TRANSFORMERS,
                                      info=method,
                                      fatal=False)
        delattr(config, 'method')

        #
        # Create the corresponding Transformer
        #

        # SEQUENTIAL
        if DEEP_TRANSFORMER_SEQUENTIAL.corresponds(flag):
            transformer = Sequential(**config.get())

        # ONE OF
        elif DEEP_TRANSFORMER_ONE_OF.corresponds(flag):
            transformer = OneOf(**config.get())

        # SOME OF (the instance was previously created but never assigned)
        elif DEEP_TRANSFORMER_SOME_OF.corresponds(flag):
            transformer = SomeOf(**config.get())

        # If the method does not exist
        else:
            Notification(
                DEEP_NOTIF_FATAL,
                "The following transformation method does not exist : " + str(method))

    return transformer
def __init__(self,
             model: nn.Module,
             dataset: Dataset,
             metrics: dict,
             losses: dict,
             optimizer,
             num_epochs: int,
             initial_epoch: int = 1,
             batch_size: int = 4,
             shuffle_method: Flag = DEEP_SHUFFLE_NONE,
             num_workers: int = 4,
             verbose: Flag = DEEP_VERBOSE_BATCH,
             tester: Tester = None) -> None:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Initialize a Trainer instance

    PARAMETERS:
    -----------

    :param model (torch.nn.Module): The model which has to be trained
    :param dataset (Dataset): The dataset to be trained on
    :param metrics (dict): The metrics to analyze
    :param losses (dict): The losses to use for the backpropagation
    :param optimizer: The optimizer to use for the backpropagation
    :param num_epochs (int): Number of epochs for the training
    :param initial_epoch (int): The index of the initial epoch
    :param batch_size (int): Size of a minibatch
    :param shuffle_method (Flag): DEEP_SHUFFLE flag, method of shuffling to use
        (DEEP_SHUFFLE_NONE if unrecognised)
    :param num_workers (int): Number of processes / threads to use for data loading
    :param verbose (Flag): DEEP_VERBOSE flag, how verbose the Trainer is
    :param tester (Tester): The tester to use for validation; anything that is not a Tester
        disables validation (self.tester = None)

    RETURN:
    -------

    :return: None
    """
    super().__init__(model=model,
                     dataset=dataset,
                     metrics=metrics,
                     losses=losses,
                     batch_size=batch_size,
                     num_workers=num_workers,
                     verbose=verbose)
    self.optimizer = optimizer
    self.initial_epoch = initial_epoch
    self.epoch = None
    self.validation_loss = None
    self.num_epochs = num_epochs

    self.shuffle_method = get_corresponding_flag(DEEP_LIST_SHUFFLE,
                                                 shuffle_method,
                                                 fatal=False,
                                                 default=DEEP_SHUFFLE_NONE)

    if not isinstance(tester, Tester):
        self.tester = None
    else:
        # The tester evaluates with the same metrics and losses as the training
        self.tester = tester
        self.tester.set_metrics(metrics=metrics)
        self.tester.set_losses(losses=losses)

    # Connect signals
    Thalamus().connect(receiver=self.saving_required,
                       event=DEEP_EVENT_SAVING_REQUIRED,
                       expected_arguments=["saving_required"])
    Thalamus().connect(receiver=self.send_save_params,
                       event=DEEP_EVENT_REQUEST_SAVE_PARAMS_FROM_TRAINER,
                       expected_arguments=[])