from pathlib import Path
from typing import Optional, Type

import torch
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler  # public `LRScheduler` in torch >= 2.0

import hylfm.metrics

# Project-internal names used below (Checkpoint, DatasetPart, HyLFM_Net, Period,
# PredictPathRunConfig, ValidationRun, ValidationRunConfig, WandbLogger,
# get_config_for_old_checkpoint, get_criterion, get_model, OptimizerChoice,
# LRSchedulerChoice) are assumed to be imported from elsewhere in the hylfm package.


def load_model_from_state(checkpoint: Path, state: dict):
    # Older checkpoints were saved without an embedded config; reconstruct it
    # from the checkpoint path instead.
    config = state.get("config", None)
    if config is None:
        config = get_config_for_old_checkpoint(checkpoint)

    config["checkpoint"] = str(checkpoint)
    model = get_model(**config["model"])
    model.load_state_dict(state["model"], strict=True)
    return model, config
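# Usage sketch, assuming the checkpoint was written with torch.save and follows
# this project's {"config": ..., "model": ...} layout (the path is illustrative):
#
#   ckpt = Path("checkpoints/run42.pth")
#   state = torch.load(ckpt, map_location="cpu")
#   model, config = load_model_from_state(ckpt, state)
#   model.eval()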
def __init__(self, wandb_run, config: PredictPathRunConfig, log_level_wandb: int):
    # Rebuild the network from the checkpoint's config and restore its weights.
    model: HyLFM_Net = get_model(**config.checkpoint.config.model)
    model.load_state_dict(config.checkpoint.model_weights, strict=True)
    self.wandb_run = wandb_run

    scale = model.get_scale()
    super().__init__(
        config=config,
        dataset_part=DatasetPart.predict,
        model=model,
        name=config.checkpoint.training_run_name,
        run_logger=WandbLogger(
            point_cloud_threshold=config.point_cloud_threshold,
            zyx_scaling=(5, 0.7 * 8 / scale, 0.7 * 8 / scale),
        ),
        log_level_wandb=log_level_wandb,
    )
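# Worked example of the logger scaling above (an illustration, not project
# documentation): for a network with model.get_scale() == 8 the point cloud is
# scaled by zyx_scaling == (5, 0.7, 0.7); at scale 4 the lateral factors double
# to (5, 1.4, 1.4), i.e. the 0.7 * 8 / scale term compensates for coarser output.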
def __init__(self, *, wandb_run, checkpoint: Checkpoint):
    cfg = checkpoint.config

    # Rebuild the network from the checkpoint config; restore weights if present.
    model: HyLFM_Net = get_model(**cfg.model)
    if checkpoint.model_weights is not None:
        model.load_state_dict(checkpoint.model_weights, strict=True)

    self.wandb_run = wandb_run
    assert wandb_run.name == checkpoint.training_run_name

    scale = model.get_scale()
    super().__init__(
        config=checkpoint.config,
        dataset_part=DatasetPart.train,
        model=model,
        name=checkpoint.training_run_name,
        run_logger=WandbLogger(
            point_cloud_threshold=cfg.point_cloud_threshold,
            zyx_scaling=(5, 0.7 * 8 / scale, 0.7 * 8 / scale),
        ),
    )
    self.current_best_checkpoint_on_disk: Optional[Path] = None
    self.criterion = get_criterion(config=self.config, transforms_pipeline=self.transforms_pipeline)

    # Optimizer: resolved by name from torch.optim; SGD additionally takes momentum.
    opt_class: Type[Optimizer] = getattr(torch.optim, self.config.optimizer.name)
    opt_kwargs = {"lr": self.config.opt_lr, "weight_decay": self.config.opt_weight_decay}
    if self.config.optimizer == OptimizerChoice.SGD:
        opt_kwargs["momentum"] = self.config.opt_momentum

    self.optimizer: Optimizer = opt_class(self.model.parameters(), **opt_kwargs)
    # zero_grad() is called here (not per step) because of how batch_multiplier
    # is implemented in TrainRun: gradients accumulate across sub-batches.
    self.optimizer.zero_grad()
    if checkpoint.optimizer_state_dict is not None:
        self.optimizer.load_state_dict(checkpoint.optimizer_state_dict)

    # LR scheduler: only ReduceLROnPlateau is wired up; its mode follows whether
    # the score metric is minimized or maximized.
    if self.config.lr_scheduler is None:
        assert checkpoint.lr_scheduler_state_dict is None
        self.lr_scheduler = None
    else:
        sched_class: Type[LRScheduler] = getattr(torch.optim.lr_scheduler, self.config.lr_scheduler.name)
        if self.config.lr_scheduler == LRSchedulerChoice.ReduceLROnPlateau:
            sched_kwargs = dict(
                mode="min" if getattr(hylfm.metrics, cfg.score_metric.name).minimize else "max",
                factor=cfg.lr_sched_factor,
                patience=cfg.lr_sched_patience,
                threshold=cfg.lr_sched_thres,
                threshold_mode=cfg.lr_sched_thres_mode,
                cooldown=0,
                min_lr=1e-7,
            )
        else:
            raise NotImplementedError(self.config.lr_scheduler)

        self.lr_scheduler: LRScheduler = sched_class(self.optimizer, **sched_kwargs)
        if checkpoint.lr_scheduler_state_dict is not None:
            self.lr_scheduler.load_state_dict(checkpoint.lr_scheduler_state_dict)

    # Validation shares the model and reuses the training config's eval settings.
    self.validator = ValidationRun(
        config=ValidationRunConfig(
            batch_size=cfg.eval_batch_size,
            data_range=cfg.data_range,
            dataset=cfg.dataset,
            interpolation_order=cfg.interpolation_order,
            save_output_to_disk={},
            win_sigma=cfg.win_sigma,
            win_size=cfg.win_size,
            hylfm_version=cfg.hylfm_version,
            point_cloud_threshold=cfg.point_cloud_threshold,
        ),
        model=model,
        score_metric=cfg.score_metric,
        name=self.name,
    )
    self.validate_every = Period(cfg.validate_every_value, cfg.validate_every_unit)
    self.epoch_len = len(self.dataloader)

    # Restore training progress and early-stopping state from the checkpoint.
    self.best_validation_score = checkpoint.best_validation_score
    self.epoch = checkpoint.epoch
    self.impatience = checkpoint.impatience
    self.iteration = checkpoint.iteration
    self.training_run_id = checkpoint.training_run_id
    self.validation_iteration = checkpoint.validation_iteration
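# Minimal sketch of the batch_multiplier pattern that motivates the early
# zero_grad() above (a hypothetical loop; the actual TrainRun step logic is
# not shown in this file):
#
#   for i, (x, y) in enumerate(dataloader):
#       loss = criterion(model(x), y) / batch_multiplier
#       loss.backward()                      # gradients accumulate across sub-batches
#       if (i + 1) % batch_multiplier == 0:
#           optimizer.step()
#           optimizer.zero_grad()            # reset only after each optimizer step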