Пример #1
0
 # Example training invocation — train() and the *Choice enums are defined elsewhere in the project.
 train(
     dataset=DatasetChoice.heart_dyn_refine,
     # batching
     batch_multiplier=2,
     batch_size=5,
     # weighted-criterion (crit_*) settings and loss choice
     crit_apply_weight_above_threshold=False,
     crit_beta=1.0,
     crit_decay_weight_by=0.97,
     crit_decay_weight_every_unit=PeriodUnit.iteration,
     crit_decay_weight_every_value=100,
     crit_decay_weight_limit=1.0,
     crit_ms_ssim_weight=0.01,
     crit_threshold=1.0,
     crit_weight=0.001,
     criterion=CriterionChoice.WeightedL1_MS_SSIM,
     data_range=1.0,
     eval_batch_size=5,
     interpolation_order=2,
     # learning-rate scheduler
     lr_sched_factor=0.5,
     lr_sched_patience=10,
     lr_sched_thres=0.0001,
     lr_sched_thres_mode=LRSchedThresMode.abs,
     lr_scheduler=LRSchedulerChoice.ReduceLROnPlateau,
     max_epochs=200,
     model_weights=None, # Path("/g/kreshuk/LF_computed/lnet/logs/heart2/z_out49/f4_b2/20-05-20_10-18-11/train/run000/checkpoints/v1_checkpoint_MSSSIM=0.6722144321961836.pth"),
     # optimizer
     opt_lr=5e-4,
     opt_momentum=0.0,
     opt_weight_decay=0.0,
     optimizer=OptimizerChoice.Adam,
     patience=10,
     score_metric=MetricChoice.MS_SSIM,
     seed=None,
     # validation cadence
     validate_every_unit=PeriodUnit.iteration,
     validate_every_value=400,
     # window parameters (presumably for the MS-SSIM metric/loss) — TODO confirm
     win_sigma=1.5,
     win_size=11,
     # model
     nnum=19,
     z_out=49,
     kernel2d=3,
     # 2d-path channel widths; 0 presumably disables that conv/upsample layer — TODO confirm
     c00_2d=1024,
     c01_2d=512,
     c02_2d=256,
     c03_2d=256,
     c04_2d=256,
     up0_2d=128,
     c10_2d=128,
     c11_2d=128,
     c12_2d=128,
     c13_2d=0,
     c14_2d=0,
     up1_2d=0,
     c20_2d=0,
     c21_2d=0,
     c22_2d=0,
     c23_2d=0,
     c24_2d=0,
     up2_2d=0,
     c30_2d=0,
     c31_2d=0,
     c32_2d=0,
     c33_2d=0,
     c34_2d=0,
     last_kernel2d=1,
     # 3d-path channel widths
     cin_3d=16,
     kernel3d=3,
     c00_3d=16,
     c01_3d=16,
     c02_3d=0,
     c03_3d=0,
     c04_3d=0,
     up0_3d=8,
     c10_3d=8,
     c11_3d=4,
     c12_3d=0,
     c13_3d=0,
     c14_3d=0,
     up1_3d=0,
     c20_3d=0,
     c21_3d=0,
     c22_3d=0,
     c23_3d=0,
     c24_3d=0,
     up2_3d=0,
     c30_3d=0,
     c31_3d=0,
     c32_3d=0,
     c33_3d=0,
     c34_3d=0,
     init_fn=HyLFM_Net.InitName.xavier_uniform_,
     final_activation=None,
 )
Пример #2
0
 # Example training invocation — train() and the *Choice enums are defined elsewhere in the project.
 train(
     dataset=DatasetChoice.beads_highc_b,
     # batching
     batch_multiplier=2,
     batch_size=1,
     # weighted-criterion (crit_*) settings and loss choice
     crit_apply_weight_above_threshold=False,
     crit_beta=1.0,
     crit_decay_weight_by=0.8,
     crit_decay_weight_every_unit=PeriodUnit.epoch,
     crit_decay_weight_every_value=1,
     crit_decay_weight_limit=1.0,
     crit_ms_ssim_weight=0.01,
     crit_threshold=0.5,
     crit_weight=0.001,
     criterion=CriterionChoice.WeightedSmoothL1,
     data_range=1.0,
     eval_batch_size=1,
     interpolation_order=2,
     # learning-rate scheduler
     lr_sched_factor=0.5,
     lr_sched_patience=10,
     lr_sched_thres=0.0001,
     lr_sched_thres_mode=LRSchedThresMode.abs,
     lr_scheduler=LRSchedulerChoice.ReduceLROnPlateau,
     max_epochs=10,
     model_weights=None,  # Path()
     # optimizer
     opt_lr=3e-4,
     opt_momentum=0.0,
     opt_weight_decay=0.0,
     optimizer=OptimizerChoice.Adam,
     patience=5,
     score_metric=MetricChoice.MS_SSIM,
     seed=None,
     # validation cadence
     validate_every_unit=PeriodUnit.epoch,
     validate_every_value=1,
     # window parameters (presumably for the MS-SSIM metric/loss) — TODO confirm
     win_sigma=1.5,
     win_size=11,
     # model
     nnum=19,
     z_out=51,
     kernel2d=3,
     # 2d-path channel widths; 0 presumably disables that conv/upsample layer — TODO confirm
     c00_2d=976,
     c01_2d=976,
     c02_2d=0,
     c03_2d=0,
     c04_2d=0,
     up0_2d=488,
     c10_2d=488,
     c11_2d=0,
     c12_2d=0,
     c13_2d=0,
     c14_2d=0,
     up1_2d=244,
     c20_2d=244,
     c21_2d=0,
     c22_2d=0,
     c23_2d=0,
     c24_2d=0,
     up2_2d=0,
     c30_2d=0,
     c31_2d=0,
     c32_2d=0,
     c33_2d=0,
     c34_2d=0,
     last_kernel2d=1,
     # 3d-path channel widths
     cin_3d=7,
     kernel3d=3,
     c00_3d=7,
     c01_3d=0,
     c02_3d=0,
     c03_3d=0,
     c04_3d=0,
     up0_3d=7,
     c10_3d=7,
     c11_3d=7,
     c12_3d=0,
     c13_3d=0,
     c14_3d=0,
     up1_3d=0,
     c20_3d=0,
     c21_3d=0,
     c22_3d=0,
     c23_3d=0,
     c24_3d=0,
     up2_3d=0,
     c30_3d=0,
     c31_3d=0,
     c32_3d=0,
     c33_3d=0,
     c34_3d=0,
     init_fn=HyLFM_Net.InitName.xavier_uniform_,
     final_activation=None,
 )
Пример #3
0
 # Example training invocation — train() and the *Choice enums are defined elsewhere in the project.
 train(
     dataset=DatasetChoice.heart_2020_02_fish1_static_sliced,
     # batching
     batch_multiplier=2,
     batch_size=5,
     # weighted-criterion (crit_*) settings and loss choice
     crit_apply_weight_above_threshold=False,
     crit_beta=1.0,
     crit_decay_weight_by=0.97,
     crit_decay_weight_every_unit=PeriodUnit.iteration,
     crit_decay_weight_every_value=100,
     crit_decay_weight_limit=1.0,
     crit_ms_ssim_weight=0.01,
     crit_threshold=1.0,
     crit_weight=0.001,
     criterion=CriterionChoice.WeightedL1_MS_SSIM,
     data_range=1.0,
     eval_batch_size=5,
     interpolation_order=2,
     # learning-rate scheduler
     lr_sched_factor=0.5,
     lr_sched_patience=10,
     lr_sched_thres=0.0001,
     lr_sched_thres_mode=LRSchedThresMode.abs,
     lr_scheduler=LRSchedulerChoice.ReduceLROnPlateau,
     max_epochs=200,
     # alternative starting checkpoint kept for reference:
     # model_weights=Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/fake_dyn/from_scratch/21-01-12_21-05-48/train_01/run000/checkpoints/v1_checkpoint_37400_ms_ssim-scaled=0.8433429219506003.pth"),
     model_weights=Path(
         "/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/fake_dyn/from_dyn_incl_dyn/21-01-13_10-02-44/train_01/run000/checkpoints/v1_checkpoint_9900_ms_ssim-scaled=0.8582265810533003.pth"
     ),
     # optimizer
     opt_lr=5e-4,
     opt_momentum=0.0,
     opt_weight_decay=0.0,
     optimizer=OptimizerChoice.Adam,
     patience=10,
     score_metric=MetricChoice.MS_SSIM,
     seed=None,
     # validation cadence
     validate_every_unit=PeriodUnit.iteration,
     validate_every_value=400,
     # window parameters (presumably for the MS-SSIM metric/loss) — TODO confirm
     win_sigma=1.5,
     win_size=11,
     # model
     nnum=19,
     z_out=49,
     kernel2d=3,
     # 2d-path channel widths; 0 presumably disables that conv/upsample layer — TODO confirm
     c00_2d=488,
     c01_2d=488,
     c02_2d=0,
     c03_2d=0,
     c04_2d=0,
     up0_2d=244,
     c10_2d=244,
     c11_2d=0,
     c12_2d=0,
     c13_2d=0,
     c14_2d=0,
     up1_2d=0,
     c20_2d=0,
     c21_2d=0,
     c22_2d=0,
     c23_2d=0,
     c24_2d=0,
     up2_2d=0,
     c30_2d=0,
     c31_2d=0,
     c32_2d=0,
     c33_2d=0,
     c34_2d=0,
     last_kernel2d=1,
     # 3d-path channel widths
     cin_3d=7,
     kernel3d=3,
     c00_3d=7,
     c01_3d=0,
     c02_3d=0,
     c03_3d=0,
     c04_3d=0,
     up0_3d=7,
     c10_3d=7,
     c11_3d=7,
     c12_3d=0,
     c13_3d=0,
     c14_3d=0,
     up1_3d=0,
     c20_3d=0,
     c21_3d=0,
     c22_3d=0,
     c23_3d=0,
     c24_3d=0,
     up2_3d=0,
     c30_3d=0,
     c31_3d=0,
     c32_3d=0,
     c33_3d=0,
     c34_3d=0,
     init_fn=HyLFM_Net.InitName.xavier_uniform_,
     final_activation=None,
 )
Пример #4
0
def train_model_like(
    model_kwargs_from_checkpoint: Path,
    batch_multiplier: Optional[int] = typer.Option(None, "--batch_multiplier"),
    batch_size: Optional[int] = typer.Option(None, "--batch_size"),
    crit_decay: Optional[float] = typer.Option(None, "--crit_decay"),
    crit_decay_weight_every_unit: Optional[PeriodUnit] = typer.Option(None, "--crit_decay_weight_every_unit"),
    crit_decay_weight_every_value: Optional[int] = typer.Option(None, "--crit_decay_weight_every_value"),
    crit_threshold: Optional[float] = typer.Option(None, "--crit_threshold"),
    crit_weight: Optional[float] = typer.Option(None, "--crit_weight"),
    criterion: Optional[CriterionChoice] = typer.Option(None, "--criterion"),
    dataset: Optional[DatasetChoice] = typer.Option(None, "--dataset"),
    eval_batch_size: Optional[int] = typer.Option(None, "--eval_batch_size"),
    interpolation_order: Optional[int] = typer.Option(None, "--interpolation_order"),
    lr_sched_factor: Optional[float] = typer.Option(None, "--lr_sched_factor"),
    lr_sched_patience: Optional[int] = typer.Option(None, "--lr_sched_patience"),
    lr_sched_thres: Optional[float] = typer.Option(None, "--lr_sched_thres"),
    lr_sched_thres_mode: Optional[LRSchedThresMode] = typer.Option(None, "--lr_sched_thres_mode"),
    lr_scheduler: Optional[LRSchedulerChoice] = typer.Option(None, "--lr_scheduler"),
    max_epochs: Optional[int] = typer.Option(None, "--max_epochs"),
    model_weights: Optional[Path] = typer.Option(None, "--model_weights"),
    opt_lr: Optional[float] = typer.Option(None, "--opt_lr"),
    opt_momentum: Optional[float] = typer.Option(None, "--opt_momentum"),
    opt_weight_decay: Optional[float] = typer.Option(None, "--opt_weight_decay"),
    optimizer: Optional[OptimizerChoice] = typer.Option(None, "--optimizer"),
    patience: Optional[int] = typer.Option(None, "--patience"),
    save_after_validation_iterations: Optional[List[int]] = typer.Option(None, "--save_after_validation_iterations"),
    score_metric: MetricChoice = typer.Option(MetricChoice.MS_SSIM, "--score_metric"),
    seed: Optional[int] = typer.Option(None, "--seed"),
    validate_every_unit: Optional[PeriodUnit] = typer.Option(None, "--validate_every_unit"),
    validate_every_value: Optional[int] = typer.Option(None, "--validate_every_value"),
    win_sigma: Optional[float] = typer.Option(None, "--win_sigma"),
    win_size: Optional[int] = typer.Option(None, "--win_size"),
    z_out: Optional[int] = typer.Option(None, "--z_out"),
    zero_max_patience: Optional[int] = typer.Option(None, "--zero_max_patience"),
):
    """Start a training run configured like an existing checkpoint.

    Loads the config stored in ``model_kwargs_from_checkpoint``, flattens its
    nested model kwargs into the top-level config, overrides it with every CLI
    option that was explicitly set (options left at ``None`` keep the
    checkpoint's values), and hands the result to ``train``.

    Args:
        model_kwargs_from_checkpoint: path of the checkpoint whose config is
            used as the baseline.
        (all other options): per-key overrides of the checkpoint config.
    """
    reference_checkpoint = Checkpoint.load(model_kwargs_from_checkpoint)
    config = reference_checkpoint.config.as_dict(for_logging=False)
    # train() takes the flat config without the version tag; tolerate
    # checkpoints that lack "hylfm_version" instead of raising KeyError.
    config.pop("hylfm_version", None)
    config.update(config.pop("model"))  # flatten model kwargs into config

    # Collect only the overrides the caller actually provided.
    # NOTE(review): "crit_decay" does not match the "crit_decay_weight_by"
    # kwarg used by the train() examples — confirm train() accepts this key.
    # NOTE(review): score_metric has a non-None default, so it ALWAYS
    # overrides the checkpoint's value — confirm that is intended.
    changes = {
        k: v
        for k, v in {
            "batch_multiplier": batch_multiplier,
            "batch_size": batch_size,
            "crit_decay": crit_decay,
            "crit_decay_weight_every_unit": crit_decay_weight_every_unit,
            "crit_decay_weight_every_value": crit_decay_weight_every_value,
            "crit_threshold": crit_threshold,
            "crit_weight": crit_weight,
            "criterion": criterion,
            "dataset": dataset,
            "eval_batch_size": eval_batch_size,
            "interpolation_order": interpolation_order,
            "lr_sched_factor": lr_sched_factor,
            "lr_sched_patience": lr_sched_patience,
            "lr_sched_thres": lr_sched_thres,
            "lr_sched_thres_mode": lr_sched_thres_mode,
            "lr_scheduler": lr_scheduler,
            "max_epochs": max_epochs,
            "model_weights": model_weights,
            "opt_lr": opt_lr,
            "opt_momentum": opt_momentum,
            "opt_weight_decay": opt_weight_decay,
            "optimizer": optimizer,
            "patience": patience,
            "save_after_validation_iterations": save_after_validation_iterations,
            "score_metric": score_metric,
            "seed": seed,
            "validate_every_unit": validate_every_unit,
            "validate_every_value": validate_every_value,
            "win_sigma": win_sigma,
            "win_size": win_size,
            "z_out": z_out,
            "zero_max_patience": zero_max_patience,
        }.items()
        if v is not None
    }

    # Build the note without the stray double/trailing spaces the old
    # f-string produced when there were no changes.
    note = f"train like {model_kwargs_from_checkpoint.resolve()}"
    if changes:
        note += " with changes: " + " ".join(f"{k}: {v}" for k, v in changes.items())
    logger.info(note)
    if len(note) > 500:
        # the full note was already logged; keep the copy passed on bounded
        note = note[:499] + "…"

    config.update(changes)
    train(**config, note=note)