def load_predictor(self):
    """
    DESCRIPTION:
    ------------

    When the predict step is enabled, build the input transform manager, the output
    transformer and the prediction dataset. Otherwise, notify the user that prediction
    is disabled.

    NOTE(review): the Predictor instantiation is commented out, so this method does not
    assign self.predictor and the objects built here are not stored anywhere.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    if not self.config.data.enabled.predict:
        Notification(DEEP_NOTIF_INFO,
                     DEEP_MSG_DATA_DISABLED % DEEP_DATASET_PREDICTION.name)
        return

    self.loading_message("Predictor")
    dataset_index = self.get_dataset_index(DEEP_DATASET_PREDICTION)
    dataset_config = self.config.data.datasets[dataset_index]
    Notification(DEEP_NOTIF_INFO, DEEP_NOTIF_DATA_LOADING % dataset_config.name)

    predict_transforms = self.config.transform.predict

    # Input transforms: the predict transform config minus its "outputs" entry
    transform_manager = TransformManager(
        **predict_transforms.get(ignore="outputs"))

    # Output transforms: only consumed by the (currently disabled) Predictor
    output_transform_manager = OutputTransformer(
        transform_files=predict_transforms.get("outputs"))

    # batch_size is excluded: it is a dataloader setting, not a Dataset argument
    dataset = Dataset(**dataset_config.get(ignore=["batch_size"]),
                      transform_manager=transform_manager)

    # TODO(review): Predictor construction is disabled; re-enable or remove:
    # self.predictor = Predictor(
    #     **self.config.data.dataloader.get(),
    #     batch_size=dataset_config.batch_size,
    #     name="Predictor",
    #     model=self.model,
    #     dataset=dataset,
    #     transform_manager=output_transform_manager)
def fit(self, first_training: bool = True) -> None:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Train the model on the dataset. This delegates to the private training routine and
    emits a notification before it starts and after it returns.

    PARAMETERS:
    -----------

    :param first_training: (bool, optional): True when this is the model's first training run

    RETURN:
    -------

    :return: None
    """
    Notification(DEEP_NOTIF_INFO, DEEP_MSG_TRAINING_STARTED)

    # The actual training work happens in the private routine
    self.__train(first_training=first_training)

    Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_TRAINING_FINISHED)
def summary(self) -> None:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Print a short summary of this pointer transformer: its name, and the index and name
    of the transformer entry it points to.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    separator = "------------------------------------"
    Notification(DEEP_NOTIF_INFO, separator)
    Notification(DEEP_NOTIF_INFO, "Transformer '%s' summary :" % (self.name,))
    target_name = self.transformer_entry.get_name()
    Notification(DEEP_NOTIF_INFO,
                 "Points to the %ith %s." % (self.transformer_index, target_name))
    Notification(DEEP_NOTIF_INFO, separator)
def __import_cv_library(self):
    """
    AUTHORS:
    --------

    author: Samuel Westlake

    DESCRIPTION:
    ------------

    Import either cv2 or PIL.Image, depending on the value of self.cv_library, and bind
    the result at module (global) scope.
    The "library set" message is only emitted after the import has succeeded. A failed
    import is reported with an error notification instead.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    None
    """
    # Declared up front so the imports below bind module-level names
    global cv2, Image

    if self.cv_library == DEEP_LIB_OPENCV:
        try:
            import cv2
        except ImportError as e:
            Notification(DEEP_NOTIF_ERROR, str(e))
        else:
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_CV_LIBRARY_SET % "OPENCV")
    elif self.cv_library == DEEP_LIB_PIL:
        try:
            from PIL import Image
        except ImportError as e:
            Notification(DEEP_NOTIF_ERROR, str(e))
        else:
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_CV_LIBRARY_SET % "PILLOW")
    else:
        Notification(DEEP_NOTIF_ERROR,
                     DEEP_MSG_CV_LIBRARY_NOT_IMPLEMENTED % self.cv_library)
def load_transform_functions(self):
    """
    Load the callable behind every transform of every output transform sequence.

    For each transform entry, get_module locates the method. Plain Python functions are
    stored as they are; anything else is instantiated with the entry's kwargs. The entry's
    "module" field is then set to the path the method was loaded from.

    :return: None
    """
    for sequence_index, sequence in enumerate(self.output_transformer):
        # Verify that the "transforms" entry is present for this sequence
        self.__check_transforms_exists(sequence, sequence_index)
        if sequence.transforms is None:
            continue

        for name, info in sequence.transforms.get().items():
            self.__transform_loading(name, info)
            method, module_path = get_module(
                **info.get(ignore="kwargs"),
                browse=DEEP_MODULE_TRANSFORMS)

            if method is None:
                self.__method_not_found(info)

            if not isinstance(method, types.FunctionType):
                # Non-function callables (e.g. classes) are instantiated with the
                # configured kwargs; a TypeError is reported as fatal
                try:
                    info.add({"method": method(**info.kwargs.get())})
                except TypeError as e:
                    Notification(DEEP_NOTIF_FATAL, str(e))
            else:
                info.add({"method": method})

            info.module = module_path
            Notification(
                DEEP_NOTIF_SUCCESS,
                "Loaded transform : %s : %s from %s" % (name, info.name, info.module))
def __thanks_master():
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Display the closing thank-you banner as a sequence of info notifications.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None (the banner is only emitted through notifications)
    """
    banner = (
        "=================================",
        "Thank you for using Deeplodocus !",
        "== Made by Humans with deep <3 ==",
        "=================================",
    )
    for line in banner:
        Notification(DEEP_NOTIF_INFO, line)
def train(self, num_epochs: Union[int, None] = None):
    """
    Run the training loop.

    Every missing component (model, losses, optimizer) is reported first. If any of them
    is missing, training is aborted instead of crashing later on (e.g. in vars(None)).

    :param num_epochs: number of epochs to run; when None, the current self.num_epochs is kept
    :return: None
    """
    # Pre-training checks: report every missing component, then stop if any is missing
    missing = [
        component for component, value in (
            ("model", self.model),
            ("losses", self.losses),
            ("optimizer", self.optimizer),
        ) if value is None
    ]
    for component in missing:
        Notification(
            DEEP_NOTIF_ERROR,
            "Could not begin training : No %s detected by the trainer" % component)
    if missing:
        return

    # Update num_epochs only when explicitly given
    if num_epochs is not None:
        self.num_epochs = num_epochs

    # Infer the initial epoch from the model when it carries an "epoch" instance attribute
    if self.initial_epoch is None:
        self.initial_epoch = self.model.epoch if "epoch" in vars(self.model) else 0

    self.training_start()
    first_epoch = self.initial_epoch + 1
    for self.epoch in range(first_epoch, first_epoch + self.num_epochs):
        self.epoch_start()
        for self.batch_index, batch in enumerate(self.dataloader, 1):
            # accumulate == 1 means a plain step; otherwise forward2 is used
            if self.accumulate == 1:
                self.forward(batch)
            else:
                self.forward2(batch)
        self.epoch_end()
    self.training_end()
def summary():
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Print the summary of a NoTransformer, i.e. report that the entry has no transformer.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    separator = "------------------------------------"
    for message in (separator, "No Transformer for this entry", separator):
        Notification(DEEP_NOTIF_INFO, message)
def __set_number(self) -> None:
    """
    AUTHORS:
    --------

    author: Samuel Westlake, Alix Leroy

    DESCRIPTION:
    ------------

    Set the number of instances used from the dataset:
    - if self.number is None, use the full length of self.data;
    - if self.number exceeds the data length, clamp it and warn, reporting both the
      requested number and the available length;
    - otherwise keep it and report it.

    RETURN:
    -------

    :return: None
    """
    data_length = len(self.data)

    if self.number is None:
        self.number = data_length
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_DATA_NO_LENGTH % data_length)
    elif self.number > data_length:
        # Keep the requested value for the warning before clamping it
        requested = self.number
        self.number = data_length
        Notification(DEEP_NOTIF_WARNING,
                     DEEP_MSG_DATA_TOO_LONG % (requested, data_length))
    else:
        Notification(DEEP_NOTIF_INFO,
                     DEEP_MSG_DATA_LENGTH % (data_length, self.number))
def shuffle(self, method: int) -> None:
    """
    AUTHORS:
    --------

    author: Alix Leroy

    DESCRIPTION:
    ------------

    Shuffle the dataframe containing the data, then reset the TransformManager.

    PARAMETERS:
    -----------

    :param method (int): The shuffling flag; only DEEP_SHUFFLE_ALL is supported

    RETURN:
    -------

    :return: None
    """
    if method == DEEP_SHUFFLE_ALL:
        try:
            self.data = self.data.sample(frac=1).reset_index(drop=True)
        except Exception as e:
            # Shuffling is best-effort: report the cause instead of silently swallowing
            # it, but let KeyboardInterrupt/SystemExit propagate (unlike a bare except)
            Notification(DEEP_NOTIF_ERROR, "Cannot shuffle the dataset : %s" % str(e))
    else:
        Notification(DEEP_NOTIF_ERROR, "The shuffling method does not exist.")

    # Reset the TransformManager
    self.reset()
def generate_sources(self, sources: List[dict]) -> None:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Generate the Source instances described by the configuration and store them in
    self.sources. Classes that do not subclass Source are wrapped in a SourceWrapper;
    Source subclasses are instantiated directly with their kwargs.
    A source whose module cannot be found is reported as fatal and skipped.
    Does not generate the SourcePointer instances.

    PARAMETERS:
    -----------

    :param sources(List[dict]): The configuration of the Source instances to generate

    RETURN:
    -------

    :return: None
    """
    list_sources = []

    for i, source in enumerate(sources):
        origin = "default modules" if source["module"] is None else source["module"]
        Notification(DEEP_NOTIF_INFO,
                     DEEP_MSG_LOADING % ("source [%i]" % i, source["name"], origin))

        method, module_path = get_module(name=source["name"],
                                         module=source["module"],
                                         browse=DEEP_MODULE_SOURCES)

        if method is None:
            # A missing module is a failure, not a success; skip it rather than
            # crashing in issubclass(None, Source) below
            Notification(DEEP_NOTIF_FATAL,
                         DEEP_MSG_MODULE_NOT_FOUND % (source["name"], module_path))
            continue

        Notification(DEEP_NOTIF_SUCCESS,
                     DEEP_MSG_LOADED % ("source [%i]" % i, source["name"], module_path))

        if not issubclass(method, Source):
            # Not a real Source: remove the index from the initial kwargs and wrap it
            index = source["kwargs"].pop('index', None)
            s = SourceWrapper(index=index,
                              name=source["name"],
                              module=source["module"],
                              kwargs=source["kwargs"])
        else:
            # A real Source subclass
            s = method(**source["kwargs"])

        # Check the module inherits the generic Source class
        self.check_type_sources(s, i)
        list_sources.append(s)

    self.sources = list_sources
def __check_directory(self) -> None:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Verify that self.path is an existing directory. A success is notified when it is;
    a fatal notification is emitted when it is not.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    path = self.path
    if os.path.isdir(path):
        Notification(DEEP_NOTIF_SUCCESS,
                     'Source folder "%s" successfully found' % path)
    else:
        Notification(DEEP_NOTIF_FATAL,
                     "The following path is not a Source folder : %s " % path)
def load_optimizer(self):
    """
    AUTHORS:
    --------

    :author: Samuel Westlake
    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Build the optimizer from the model parameters and self.config.optimizer, and store it
    in self.optimizer. An error is notified when no model is loaded.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    if self.model is None:
        # NOTE(review): the not-loaded message is formatted with DEEP_MSG_MODEL_LOADED;
        # confirm this is the intended argument
        Notification(DEEP_NOTIF_ERROR,
                     DEEP_MSG_OPTIMIZER_NOT_LOADED % DEEP_MSG_MODEL_LOADED)
        return

    optimizer_config = self.config.optimizer
    self.optimizer = Optimizer(self.model.parameters(), optimizer_config).get()
    Notification(DEEP_NOTIF_SUCCESS,
                 DEEP_MSG_OPTIMIZER_LOADED % (optimizer_config.name,
                                              self.optimizer.__module__))
def load_predictor(self):
    """
    Build the prediction dataset, its input/output transform managers and the Predictor
    when the predict step is enabled. Otherwise, notify that the prediction set is disabled.

    :return: None
    """
    if not self.config.data.enabled.predict:
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_DATA_DISABLED % "Prediction set")
        return

    predict_index = self.get_dataset_index("predict")
    dataset_config = self.config.data.datasets[predict_index]
    Notification(DEEP_NOTIF_INFO, DEEP_NOTIF_DATA_LOADING % dataset_config.name)

    predict_transforms = self.config.transform.predict
    # Input transforms: the predict transform config minus its "outputs" entry
    input_manager = TransformManager(**predict_transforms.get(ignore="outputs"))
    # Output transforms are handed to the Predictor
    output_manager = OutputTransformer(
        transform_files=predict_transforms.get("outputs"))

    dataset = Dataset(**dataset_config.get(ignore="type"),
                      transform_manager=input_manager)

    self.predictor = Predictor(**self.config.data.dataloader.get(),
                               model=self.model,
                               dataset=dataset,
                               transform_manager=output_manager)
def __import(self, *args):
    """
    Import each named module, searching every module family in DEEP_LIST_MODULE.

    If a name matches exactly one module, that module is imported. If it matches several,
    they are listed and the user is asked for an index until a valid one (an integer
    within the listed range) is given. If nothing matches, an error is notified.

    :param args: names of the modules to import
    :return: None
    """
    for arg in args:
        items = []
        for m in DEEP_LIST_MODULE:
            method, module_path = get_module(arg, browse=m)
            if method is not None:
                items.append({
                    "method": method,
                    "module_path": module_path,
                    "file": inspect.getsourcefile(method),
                    "prefix": m["custom"]["prefix"],
                    "name": arg
                })

        if not items:
            Notification(DEEP_NOTIF_ERROR, "module not found : %s" % arg)
        elif len(items) == 1:
            self.__import_module(items[0])
        else:
            Notification(DEEP_NOTIF_INFO,
                         "Multiple modules found with the same name :")
            for i, item in enumerate(items):
                Notification(DEEP_NOTIF_INFO,
                             " %i. %s from %s" % (i, arg, item["module_path"]))
            while True:
                choice = Notification(DEEP_NOTIF_INPUT, "Which would you like?").get()
                try:
                    choice = int(choice)
                except ValueError:
                    continue
                # Reject indices outside the listed range; negative values would
                # otherwise silently select from the end of the list
                if 0 <= choice < len(items):
                    break
            self.__import_module(items[choice])
def __check_args(self) -> dict:
    """
    Map each argument of the loss/metric method to an entry flag.

    For plain functions, the function's own arguments are inspected. Otherwise the
    arguments of self.method.forward are inspected, with "self" removed.
    Custom methods resolve each argument name through DEEP_LIST_ENTRY. Default methods
    accept only the fixed aliases for the output entry ("input", "x", "inputs") and the
    label entry ("y", "y_hat", "label", "labels", "target", "targets").
    An unexpected argument triggers a fatal notification.

    :return: dict mapping argument name -> entry flag
    """
    is_function = inspect.isfunction(self.method)
    inspected = self.method if is_function else self.method.forward
    arg_names = inspect.getfullargspec(inspected).args
    if not is_function:
        arg_names.remove("self")

    flag_to_arg = {}
    for name in arg_names:
        if self.is_custom:
            flag = get_corresponding_flag(DEEP_LIST_ENTRY, name, fatal=False)
        elif name in ("input", "x", "inputs"):
            flag = DEEP_ENTRY_OUTPUT
        elif name in ("y", "y_hat", "label", "labels", "target", "targets"):
            flag = DEEP_ENTRY_LABEL
        else:
            flag = None

        if flag is None:
            Notification(DEEP_NOTIF_FATAL, DEEP_MSG_LOSS_UNEXPECTED_ARG % name)
        else:
            flag_to_arg[flag] = name

    self.check_essential_args(flag_to_arg)

    # Invert to argument name -> entry flag
    return {name: flag for flag, name in flag_to_arg.items()}
def summary(self):
    """
    Print the optimizer's name, its module, and every keyword argument except "name".

    :return: None
    """
    banner = '================================================================'
    Notification(DEEP_NOTIF_INFO, banner)
    Notification(DEEP_NOTIF_INFO, "OPTIMIZER : %s from %s" % (self.name, self.module))
    Notification(DEEP_NOTIF_INFO, banner)

    for key, value in self.kwargs.items():
        if key == "name":
            # The name is already shown in the header
            continue
        Notification(DEEP_NOTIF_INFO, "%s : %s" % (key, value))

    Notification(DEEP_NOTIF_INFO, "")
def summary(self):
    """
    Print the transformer's name and the transformer it points to.

    :return: None
    """
    separator = "------------------------------------"
    Notification(DEEP_NOTIF_INFO, separator)
    Notification(DEEP_NOTIF_INFO, "Transformer '%s' summary :" % (self.name,))
    Notification(DEEP_NOTIF_INFO, "Points to : %s" % (self.pointer_to_transformer,))
    Notification(DEEP_NOTIF_INFO, separator)
def __check_move_axis(
        self, move_axis: Optional[List[int]]) -> Optional[List[int]]:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Check that the move_axis argument is either None or a list of integers.
    A fatal notification, naming the entry index and dataset, is emitted when it is not.

    PARAMETERS:
    -----------

    :param move_axis (Optional[List[int]]): The new order of axis

    RETURN:
    -------

    :return (Optional[List[int]]): move_axis when it is None or a list; None otherwise
    """

    def notify_invalid(solution: str) -> None:
        # Both the PipelineEntry and its Dataset are held through weakrefs, so each
        # one must be dereferenced (called) before use. The non-list branch previously
        # skipped the Dataset dereference and would fail on the weakref object.
        pipeline_entry = self.pipeline_entry()
        entry_info = pipeline_entry.get_index()
        dataset_info = pipeline_entry.get_dataset()().get_info()
        Notification(
            DEEP_NOTIF_FATAL,
            "Please check the value of the move_axis argument in the Entry %i of the Dataset %s"
            % (entry_info, dataset_info),
            solutions=solution)

    if move_axis is None:
        return None

    if not isinstance(move_axis, list):
        notify_invalid("The move_axis argument must be a list of integers index at 0")
        return None

    for axis in move_axis:
        if not isinstance(axis, int):
            notify_invalid("One of the item in the list is not an integer")
    return move_axis
def load_optimizer(self):
    """
    AUTHORS:
    --------

    :author: Samuel Westlake
    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Load the optimizer with the adequate parameters and store it in self.optimizer.
    When the model is loaded from file and the checkpoint contains an
    "optimizer_state_dict", that state is restored into the optimizer.
    A fatal notification is emitted when there is no model or the optimizer is not found.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    optimizer_config = self.config.optimizer

    # Describe where the optimizer is expected to come from (for notification purposes)
    if optimizer_config.module is None:
        optimizer_path = "%s from default modules" % optimizer_config.name
    else:
        optimizer_path = "%s from %s" % (optimizer_config.name, optimizer_config.module)
    Notification(DEEP_NOTIF_INFO, DEEP_MSG_OPTIM_LOADING % optimizer_path)

    # An optimizer cannot be loaded without a model (self.model.parameters() is required)
    if self.model is None:
        Notification(DEEP_NOTIF_FATAL, DEEP_MSG_OPTIM_MODEL_NOT_LOADED)
        return

    optimizer = load_optimizer(model_parameters=self.model.parameters(),
                               **optimizer_config.get())

    # Check for a missing optimizer BEFORE touching optimizer.module / load_state_dict,
    # which would otherwise raise AttributeError on None
    if optimizer is None:
        Notification(DEEP_NOTIF_FATAL, DEEP_MSG_OPTIM_NOT_FOUND % optimizer_path)
        return

    msg = "%s from %s" % (optimizer_config.name, optimizer.module)
    if self.config.model.from_file:
        checkpoint = self.__load_checkpoint()
        if "optimizer_state_dict" in checkpoint:
            optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
            msg += " with state dict from %s" % self.config.model.file

    self.optimizer = optimizer
    Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_OPTIM_LOADED % msg)
def load_validator(self):
    """
    AUTHORS:
    --------

    :author: Alix Leroy
    :author: Samuel Westlake

    DESCRIPTION:
    ------------

    When the validation step is enabled, build the validation dataset, its transform
    managers and a Tester, and store the Tester in self.validator. Otherwise, notify that
    the validation set is disabled.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    if not self.config.data.enabled.validation:
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_DATA_DISABLED % "Validation set")
        return

    validation_index = self.get_dataset_index("validation")
    dataset_config = self.config.data.datasets[validation_index]
    Notification(DEEP_NOTIF_INFO, DEEP_NOTIF_DATA_LOADING % dataset_config.name)

    validation_transforms = self.config.transform.validation
    # Input transforms: the validation transform config minus its "outputs" entry
    input_manager = TransformManager(**validation_transforms.get(ignore="outputs"))
    # Output transforms are handed to the Tester
    output_manager = OutputTransformer(
        transform_files=validation_transforms.get("outputs"))

    dataset = Dataset(**dataset_config.get(ignore="type"),
                      transform_manager=input_manager)

    self.validator = Tester(**self.config.data.dataloader.get(),
                            model=self.model,
                            dataset=dataset,
                            metrics=self.metrics,
                            losses=self.losses,
                            transform_manager=output_manager)
def __run_project(self):
    """
    Start a Brain on the configuration directory given as the third command-line argument
    (self.argv[2]), falling back to "./config". Warnings are emitted first because this
    command does not support custom modules.

    :return: None
    """
    for warning in (
            "The run project command does not support custom modules",
            "If you need to import custom modules use : python3 main.py"):
        Notification(DEEP_NOTIF_WARNING, warning)

    config_dir = self.argv[2] if len(self.argv) > 2 else "./config"
    Brain(config_dir=config_dir).wake()
def load_metrics(self):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Load the metrics described in self.config.metrics and store them in self.metrics as a
    Metrics container. Each metric is searched among both the metric and loss modules.
    Classes are instantiated with their kwargs; functions are used as they are.
    A metric that cannot be found is reported as fatal and is not registered.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    metrics = {}
    metrics_config = self.config.metrics.get()

    if metrics_config:
        for key, config in metrics_config.items():
            # Describe where the metric is expected to come from (for notification purposes)
            if config.module is None:
                metric_path = "%s : %s from default modules" % (key, config.name)
            else:
                metric_path = "%s : %s from %s" % (key, config.name, config.module)
            Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_LOADING % metric_path)

            metric, module = get_module(
                name=config.name,
                module=config.module,
                browse={**DEEP_MODULE_METRICS, **DEEP_MODULE_LOSSES},
                silence=True
            )

            if metric is None:
                Notification(DEEP_NOTIF_FATAL, DEEP_MSG_METRIC_NOT_FOUND % config.name)
                # Do not register a Metric wrapping None
                continue

            # Classes are instantiated with their kwargs; stand-alone functions are used directly
            method = metric(**config.kwargs.get()) if inspect.isclass(metric) else metric

            name = str(key)
            metrics[name] = Metric(name=name, method=method)
            Notification(DEEP_NOTIF_SUCCESS,
                         DEEP_MSG_METRIC_LOADED % (key, config.name, module))
    else:
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_METRIC_NONE)

    self.metrics = Metrics(metrics)
def load_tester(self):
    """
    AUTHORS:
    --------

    :author: Alix Leroy
    :author: Samuel Westlake

    DESCRIPTION:
    ------------

    When the test step is enabled, build the test dataset, its transform managers and a
    Tester, and store the Tester in self.tester. Otherwise, notify that the test dataset
    is disabled.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    if not self.config.data.enabled.test:
        Notification(DEEP_NOTIF_INFO, DEEP_MSG_DATA_DISABLED % DEEP_DATASET_TEST.name)
        return

    self.loading_message("Tester")
    dataset_index = self.get_dataset_index(DEEP_DATASET_TEST)
    dataset_config = self.config.data.datasets[dataset_index]
    Notification(DEEP_NOTIF_INFO, DEEP_NOTIF_DATA_LOADING % dataset_config.name)

    test_transforms = self.config.transform.test
    # Input transforms: the test transform config minus its "outputs" entry
    input_manager = TransformManager(**test_transforms.get(ignore="outputs"))
    # Output transforms are handed to the Tester
    output_manager = OutputTransformer(transform_files=test_transforms.get("outputs"))

    # batch_size goes to the Tester (dataloader side), not to the Dataset
    dataset = Dataset(**dataset_config.get(ignore=["batch_size"]),
                      transform_manager=input_manager)

    self.tester = Tester(**self.config.data.dataloader.get(),
                         batch_size=dataset_config.batch_size,
                         model=self.model,
                         dataset=dataset,
                         metrics=self.metrics,
                         losses=self.losses,
                         transform_manager=output_manager)
def load_losses(self):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Load every loss listed in self.config.losses into self.losses, a dict keyed by the loss
    key as a string. Each loss must define an int/float weight and be a torch.nn.Module
    instance; otherwise a fatal notification is emitted.

    PARAMETERS:
    -----------

    None

    RETURN:
    -------

    :return: None
    """
    loaded = {}
    losses_config = self.config.losses

    for key, value in losses_config.get().items():
        loss = get_module(config=value, modules=DEEP_MODULE_LOSSES)
        method = loss(**check_kwargs(get_kwargs(losses_config.get(key))))

        # The weight must be present and must be an int or a float
        if not losses_config.check("weight", key):
            Notification(DEEP_NOTIF_FATAL,
                         "The loss function %s doesn't have any weight argument" % key)
        elif get_int_or_float(value.weight) not in (DEEP_TYPE_INTEGER, DEEP_TYPE_FLOAT):
            Notification(DEEP_NOTIF_FATAL,
                         "The loss function %s doesn't have a correct weight argument" % key)

        if not isinstance(method, torch.nn.Module):
            Notification(DEEP_NOTIF_FATAL,
                         "The loss function %s is not a torch.nn.Module instance" % key)
            continue

        loaded[str(key)] = Loss(name=str(key), weight=float(value.weight), loss=method)
        Notification(DEEP_NOTIF_SUCCESS,
                     DEEP_MSG_LOSS_LOADED % (key, value.name, loss.__module__))

    self.losses = loaded
def __check_load_method(self, load_method):
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Validate the requested load method and return the corresponding DEEP_LOAD_METHOD flag.
    Integers must be one of DEEP_LOAD_METHOD_LIST. Strings are mapped by name:
    "default"/"memory" -> memory; the hard-drive spellings and "server" are recognised
    but reported as not implemented (fatal). Anything else is reported as fatal.

    PARAMETERS:
    -----------

    :param load_method: An integer flag or a string name of the loading method

    RETURN:
    -------

    :return load_method (int): The corresponding DEEP_LOAD_METHOD flag (None after a
                               fatal notification on an unknown value)
    """
    if isinstance(load_method, int):
        if load_method in DEEP_LOAD_METHOD_LIST:
            return load_method
        Notification(
            DEEP_NOTIF_FATAL,
            "The given flag '%i' does not correspond to any DEEP_LOAD_METHOD" % load_method)
        return None

    if load_method in ("default", "memory"):
        return DEEP_LOAD_METHOD_MEMORY

    if load_method in ("harddrive", "hard drive", "hard-drive", "hard_drive"):
        Notification(
            DEEP_NOTIF_FATAL,
            "Loading data using a hard drive reading is not currently implemented")
        return DEEP_LOAD_METHOD_HARDDRIVE

    if load_method == "server":
        Notification(DEEP_NOTIF_FATAL,
                     "Loading data using a server is not currently implemented")
        return DEEP_LOAD_METHOD_SERVER

    Notification(DEEP_NOTIF_FATAL,
                 "The following loading method does not exist : %s" % str(load_method))
    return None
def __check_sources(
        self, sources: Union[str, List[str]],
        join: Union[str, None, List[Union[str, None]]]) -> List[Source]:
    """
    AUTHORS:
    --------

    :author: Alix Leroy

    DESCRIPTION:
    ------------

    Check that the source is a string or a list of strings, and build one Source per
    element, paired with the jointure at the same index. A single source is first wrapped
    in a one-element list.

    PARAMETERS:
    -----------

    :param sources (Union[str, List[str]]) : The source or list of sources
    :param join (Union[str, None, List[Union[str, None]]]) : The jointure or list of jointures

    RETURN:
    -------

    :return sources(List[Source]): List of sources
    """
    formatted_sources = []

    # Convert a single element to a list of one element
    if not isinstance(sources, list):
        sources = [sources]

    join = self.__check_join(join=join)

    if len(sources) != len(join):
        Notification(
            DEEP_NOTIF_FATAL,
            "The entry %s (index %i) does not have the same amount of sources and jointures"
            % (self.entry_type.get_description(), self.entry_index))

    for i, s in enumerate(sources):
        if isinstance(s, str):
            formatted_sources.append(Source(source=s, join=join[i]))
        else:
            # The arguments were previously swapped (%i received the description string),
            # which raised TypeError instead of emitting this notification
            Notification(
                DEEP_NOTIF_FATAL,
                "The source parameter '%s' in the %i th %s entry is not valid"
                % (str(s), self.entry_index, self.entry_type.get_description()))

    return formatted_sources
def summary(self, level=0):
    """
    Print the name, module path, weight, args and kwargs, indented by `level` spaces.
    Each arg/kwarg is printed one space deeper than its heading.

    :param level: number of leading spaces for the headings
    :return: None
    """
    # Build the indents once. Previously, expressions like "%sargs: " % " " * level
    # parsed as ("%sargs: " % " ") * level, because % and * share a precedence level.
    # That repeated the whole line `level` times and printed nothing at level 0.
    indent = " " * level
    child_indent = " " * (level + 1)

    Notification(DEEP_NOTIF_INFO, "%sname: %s" % (indent, self.name))
    Notification(DEEP_NOTIF_INFO, "%smodule path: %s" % (indent, self.module_path))
    Notification(DEEP_NOTIF_INFO, "%sweight: %s" % (indent, self.f2str(self.weight)))

    if self.args:
        Notification(DEEP_NOTIF_INFO, "%sargs: " % indent)
        for arg_name, entry_flag in self.args.items():
            Notification(DEEP_NOTIF_INFO,
                         "%s%s (%s)" % (child_indent, arg_name, entry_flag.name))
    else:
        Notification(DEEP_NOTIF_INFO, "%sargs: []" % indent)

    if self.kwargs:
        Notification(DEEP_NOTIF_INFO, "%skwargs: " % indent)
        for kwarg_name, kwarg_value in self.kwargs.items():
            Notification(DEEP_NOTIF_INFO,
                         "%s%s: %s" % (child_indent, kwarg_name, kwarg_value))
    else:
        Notification(DEEP_NOTIF_INFO, "%skwargs: {}" % indent)
def load_weights(self, weights):
    """
    Initialise the network weights from the "model_state_dict" entry of a checkpoint file.

    A strict load is tried first. If it raises RuntimeError, the error detail is reported
    as a warning and the same state dict is loaded again with strict=False.
    Nothing happens when `weights` is None.

    :param weights: path to the checkpoint file, or None
    :return: None
    """
    if weights is None:
        return

    Notification(DEEP_NOTIF_INFO, "Initializing network weights from file")

    # Read the checkpoint once. Keeping torch.load outside the try also means a
    # deserialization RuntimeError is not misreported as a key mismatch.
    state_dict = torch.load(weights)['model_state_dict']
    try:
        self.load_state_dict(state_dict, strict=True)
    except RuntimeError as e:
        # Drop the leading "Error(s) in loading state_dict..." part of the message
        Notification(DEEP_NOTIF_WARNING, " :".join(str(e).split(":")[1:]).strip())
        Notification(DEEP_NOTIF_WARNING, "Ignoring unexpected key(s)")
        self.load_state_dict(state_dict, strict=False)

    Notification(DEEP_NOTIF_SUCCESS, "Successfully loaded weights from %s" % weights)
def evaluation_end(self, silent: bool = False):
    """
    Finish an evaluation pass: finalise the output transforms, then reduce the losses and
    metrics for this dataset's type. Results are notified unless `silent` is True.

    :param silent: when True, suppress the completion and result notifications
    :return: tuple (total loss, per-loss values, metric values) as returned by the reducers
    """
    # Call finish on all output transforms
    self.transform_manager.finish()

    dataset_type = self.dataset.type
    loss, losses = self.losses.reduce(dataset_type)
    metrics = self.metrics.reduce(dataset_type)

    if silent:
        return loss, losses, metrics

    Notification(DEEP_NOTIF_SUCCESS, DEEP_MSG_EVALUATION_FINISHED)
    Notification(DEEP_NOTIF_RESULT, self.compose_text(loss, losses, metrics))
    return loss, losses, metrics