Example #1
from pathlib import Path


def load_model_from_state(checkpoint: Path, state: dict):
    # Older checkpoints predate embedded configs; reconstruct one from the path.
    config = state.get("config", None)
    if config is None:
        config = get_config_for_old_checkpoint(checkpoint)

    # Record where the weights came from, then rebuild the model and restore its state.
    config["checkpoint"] = str(checkpoint)
    model = get_model(**config["model"])
    model.load_state_dict(state["model"], strict=True)
    return model, config
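
A minimal usage sketch for the loader above, assuming the checkpoint was written with torch.save as a dict holding "model" (and optionally "config") entries; the path below is hypothetical.

import torch

ckpt_path = Path("runs/example/checkpoint_best.pth")  # hypothetical path
state = torch.load(ckpt_path, map_location="cpu")
model, config = load_model_from_state(ckpt_path, state)
model.eval()  # restored weights, ready for inference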
Example #2
    def __init__(self, wandb_run, config: PredictPathRunConfig,
                 log_level_wandb: int):
        # Rebuild the network from the checkpoint config and restore its weights.
        model: HyLFM_Net = get_model(**config.checkpoint.config.model)
        model.load_state_dict(config.checkpoint.model_weights, strict=True)

        self.wandb_run = wandb_run
        # The network's output scale sets the lateral voxel size for point-cloud logging.
        scale = model.get_scale()
        super().__init__(
            config=config,
            dataset_part=DatasetPart.predict,
            model=model,
            name=config.checkpoint.training_run_name,
            run_logger=WandbLogger(
                point_cloud_threshold=config.point_cloud_threshold,
                zyx_scaling=(5, 0.7 * 8 / scale, 0.7 * 8 / scale)),
            log_level_wandb=log_level_wandb,
        )
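
For intuition on the zyx_scaling tuple above: the z spacing is fixed at 5, while the lateral spacing shrinks with the network's output scale. A quick worked example (the scale value is hypothetical):

scale = 4  # hypothetical HyLFM_Net output scale
zyx_scaling = (5, 0.7 * 8 / scale, 0.7 * 8 / scale)  # -> (5, 1.4, 1.4)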
Example #3
    def __init__(self, *, wandb_run, checkpoint: Checkpoint):
        cfg = checkpoint.config
        # Rebuild the network from the checkpoint config; weights are absent
        # when training starts from scratch.
        model: HyLFM_Net = get_model(**cfg.model)
        if checkpoint.model_weights is not None:
            model.load_state_dict(checkpoint.model_weights, strict=True)

        self.wandb_run = wandb_run
        assert wandb_run.name == checkpoint.training_run_name
        scale = model.get_scale()
        super().__init__(
            config=checkpoint.config,
            dataset_part=DatasetPart.train,
            model=model,
            name=checkpoint.training_run_name,
            run_logger=WandbLogger(
                point_cloud_threshold=cfg.point_cloud_threshold,
                zyx_scaling=(5, 0.7 * 8 / scale, 0.7 * 8 / scale)),
        )

        self.current_best_checkpoint_on_disk: Optional[Path] = None

        self.criterion = get_criterion(
            config=self.config, transforms_pipeline=self.transforms_pipeline)

        # Resolve the optimizer class by name from torch.optim; only SGD takes
        # the extra momentum argument.
        opt_class: Type[Optimizer] = getattr(torch.optim,
                                             self.config.optimizer.name)
        opt_kwargs = {
            "lr": self.config.opt_lr,
            "weight_decay": self.config.opt_weight_decay
        }
        if self.config.optimizer == OptimizerChoice.SGD:
            opt_kwargs["momentum"] = self.config.opt_momentum

        self.optimizer: Optimizer = opt_class(self.model.parameters(),
                                              **opt_kwargs)
        # zero_grad here because of how batch_multiplier is implemented in TrainRun
        self.optimizer.zero_grad()

        if checkpoint.optimizer_state_dict is not None:
            self.optimizer.load_state_dict(checkpoint.optimizer_state_dict)

        if self.config.lr_scheduler is None:
            assert checkpoint.lr_scheduler_state_dict is None
            self.lr_scheduler = None
        else:
            sched_class: Type[LRScheduler] = getattr(
                torch.optim.lr_scheduler, self.config.lr_scheduler.name)
            if self.config.lr_scheduler == LRSchedulerChoice.ReduceLROnPlateau:
                # Whether a plateau means "stopped decreasing" or "stopped
                # increasing" depends on the score metric's minimize flag.
                score_metric = getattr(hylfm.metrics, cfg.score_metric.name)
                sched_kwargs = dict(
                    mode="min" if score_metric.minimize else "max",
                    factor=cfg.lr_sched_factor,
                    patience=cfg.lr_sched_patience,
                    threshold=cfg.lr_sched_thres,
                    threshold_mode=cfg.lr_sched_thres_mode,
                    cooldown=0,
                    min_lr=1e-7,
                )
            else:
                raise NotImplementedError

            self.lr_scheduler: LRScheduler = sched_class(
                self.optimizer, **sched_kwargs)
            if checkpoint.lr_scheduler_state_dict is not None:
                self.lr_scheduler.load_state_dict(
                    checkpoint.lr_scheduler_state_dict)

        # Periodic validation runs against the same live model instance.
        self.validator = ValidationRun(
            config=ValidationRunConfig(
                batch_size=cfg.eval_batch_size,
                data_range=cfg.data_range,
                dataset=cfg.dataset,
                interpolation_order=cfg.interpolation_order,
                save_output_to_disk={},
                win_sigma=cfg.win_sigma,
                win_size=cfg.win_size,
                hylfm_version=cfg.hylfm_version,
                point_cloud_threshold=cfg.point_cloud_threshold,
            ),
            model=model,
            score_metric=cfg.score_metric,
            name=self.name,
        )
        self.validate_every = Period(cfg.validate_every_value,
                                     cfg.validate_every_unit)
        self.epoch_len = len(self.dataloader)

        # Resume bookkeeping so training continues where the checkpoint left off.
        self.best_validation_score = checkpoint.best_validation_score
        self.epoch = checkpoint.epoch
        self.impatience = checkpoint.impatience
        self.iteration = checkpoint.iteration
        self.training_run_id = checkpoint.training_run_id
        self.validation_iteration = checkpoint.validation_iteration
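
The restore order above (model weights, then optimizer state, then scheduler state) follows the standard torch resume idiom; a self-contained sketch with all names below hypothetical:

import torch
from torch import nn

# Hypothetical minimal model/optimizer pair to illustrate the restore order.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="max", factor=0.5, patience=3)

ckpt = torch.load("checkpoint.pth", map_location="cpu")  # hypothetical file
model.load_state_dict(ckpt["model"])
optimizer.load_state_dict(ckpt["optimizer"])
if ckpt.get("lr_scheduler") is not None:
    scheduler.load_state_dict(ckpt["lr_scheduler"])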