def remove_stride(trainer):
    """Neutralize superfluous strided layers of the trainer's model in place.

    Uses the module-level ``superfluous_strides`` count and ``model`` config
    to pick which layers to touch: the last ``superfluous_strides``
    convolutions get stride (1, 1) and the matching first upsampling layers
    get scale factor 1.0.

    :param trainer: trainer whose model is modified in place
    :return: the same trainer, for chaining
    """
    global superfluous_strides
    global model
    n_convs = len(model['channels']) - 1
    for idx in range(superfluous_strides):
        # conv layers are counted from the end, upsampling from the start
        conv = rgetattr(trainer.model, f'conv{n_convs - idx}.convolution')
        conv.stride = (1, 1)
        upsample = rgetattr(trainer.model, f'dconv{idx + 1}.upsampling')
        upsample.scale_factor = 1.0
    return trainer
def save(self, *args, save_config=None, checkpoint_name=None, directory='./', **kwargs):
    """
    create a checkpoint with all objects of the types specified in the
    save config

    :param save_config: dict that specifies which attributes to save.
        Keys should be classes and values conversion methods.
    :param checkpoint_name: name of the checkpoint that will be created
    :param directory: folder where checkpoints will be saved
    :return: filename of the saved checkpoint
    """
    if save_config is None:
        # shallow copy so callers' default config is never mutated here
        save_config = self.default_save_config().copy()
    # find all objects that need to be saved
    objects = {}
    for key, value in self.__dict__.items():
        # check if value is instance of classes in save_config
        cls = None
        conversion_name = None
        for obj, call in save_config.items():
            if isinstance(value, obj):
                cls = obj
                conversion_name = call
                break
        # apply conversion method specified in save_config for that class
        if cls is not None and conversion_name is not None:
            # find conversion method for present value: try the value itself
            # first (bound method, e.g. state_dict), then self, then the
            # object/type fallbacks
            for level in (value, self, object, type):
                conversion = rgetattr(level, conversion_name, None)
                if conversion is not None:
                    break
            # convert the value for storage; a method found on the value is
            # already bound, otherwise the value is passed as the argument
            # NOTE(review): if no level provides `conversion_name`,
            # `conversion` is still None here and the call below raises
            # TypeError — presumably save_config only names resolvable
            # methods; confirm against default_save_config
            objects[key] = conversion() if level is value else conversion(
                value)
    # set default save path if it is None
    if checkpoint_name is not None:
        checkpoint = f'checkpoint_{checkpoint_name}.pt'
    else:
        # timestamp-based fallback name, spaces replaced for portability
        checkpoint = f'checkpoint_{ctime().replace(" ", "-")}.pt'
    # save the objects
    pt.save(objects, os.path.join(directory, checkpoint))
    return checkpoint
def load(self, checkpoint, map_location=None, **kwargs):
    """
    load a checkpoint that can contain model and optimizer state

    :param checkpoint: filename of the checkpoint
    :param map_location: device to load models to
    :return: None
    """
    checkpoint = pt.load(checkpoint, map_location=map_location)
    # restore every stored entry that maps onto an existing attribute
    for key, state in checkpoint.items():
        # skip unknown attributes and empty/falsy entries
        if not (hasattr(self, key) and state):
            continue
        print(f'loading {key} checkpoint...', end='')
        if key in ('model', 'optimizer'):
            # stateful torch objects restore through their own state dicts
            rgetattr(self, key).load_state_dict(state)
        else:
            rsetattr(self, key, state)
        print('done')
def monitor(self, *, name, method=None):
    """Wrap a function to intercept its results.

    If only ``name`` is given, the attribute of that name is retrieved
    from self via rgetattr, wrapped, and stored back in place. If
    ``method`` is also given, that callable is wrapped and returned
    instead of being attached to self.

    :param name: name of the function / method
    :param method: callable to wrap. If this argument is specified, the
        passed callable will be wrapped and returned; otherwise the
        method is looked up on self and replaced with its wrapped version.
    :return: wrapped callable if method is not None else None
    """
    if method is not None:
        # caller supplied the callable: wrap it and hand it back
        return self._wrap(method, name)
    # otherwise fetch the attribute, wrap it, and replace it in place
    wrapped = self._wrap(rgetattr(self, name), name)
    rsetattr(self, name, wrapped)