import logging
from pathlib import Path
from typing import List, Optional

import typer

# NOTE: the hylfm-internal import paths below are assumptions about the repo
# layout; adjust them to wherever these names actually live in hylfm-net.
from hylfm.checkpoint import Checkpoint
from hylfm.hylfm_types import (
    CriterionChoice,
    DatasetChoice,
    LRSchedThresMode,
    LRSchedulerChoice,
    MetricChoice,
    OptimizerChoice,
    PeriodUnit,
)
from hylfm.model import HyLFM_Net
from hylfm.train import train

logger = logging.getLogger(__name__)

# preset: refine on dynamic heart data
train(
    dataset=DatasetChoice.heart_dyn_refine,
    batch_multiplier=2,
    batch_size=5,
    crit_apply_weight_above_threshold=False,
    crit_beta=1.0,
    crit_decay_weight_by=0.97,
    crit_decay_weight_every_unit=PeriodUnit.iteration,
    crit_decay_weight_every_value=100,
    crit_decay_weight_limit=1.0,
    crit_ms_ssim_weight=0.01,
    crit_threshold=1.0,
    crit_weight=0.001,
    criterion=CriterionChoice.WeightedL1_MS_SSIM,
    data_range=1.0,
    eval_batch_size=5,
    interpolation_order=2,
    lr_sched_factor=0.5,
    lr_sched_patience=10,
    lr_sched_thres=0.0001,
    lr_sched_thres_mode=LRSchedThresMode.abs,
    lr_scheduler=LRSchedulerChoice.ReduceLROnPlateau,
    max_epochs=200,
    model_weights=None,  # Path("/g/kreshuk/LF_computed/lnet/logs/heart2/z_out49/f4_b2/20-05-20_10-18-11/train/run000/checkpoints/v1_checkpoint_MSSSIM=0.6722144321961836.pth"),
    opt_lr=5e-4,
    opt_momentum=0.0,
    opt_weight_decay=0.0,
    optimizer=OptimizerChoice.Adam,
    patience=10,
    score_metric=MetricChoice.MS_SSIM,
    seed=None,
    validate_every_unit=PeriodUnit.iteration,
    validate_every_value=400,
    win_sigma=1.5,
    win_size=11,
    # model
    nnum=19,
    z_out=49,
    kernel2d=3,
    c00_2d=1024, c01_2d=512, c02_2d=256, c03_2d=256, c04_2d=256,
    up0_2d=128,
    c10_2d=128, c11_2d=128, c12_2d=128, c13_2d=0, c14_2d=0,
    up1_2d=0,
    c20_2d=0, c21_2d=0, c22_2d=0, c23_2d=0, c24_2d=0,
    up2_2d=0,
    c30_2d=0, c31_2d=0, c32_2d=0, c33_2d=0, c34_2d=0,
    last_kernel2d=1,
    cin_3d=16,
    kernel3d=3,
    c00_3d=16, c01_3d=16, c02_3d=0, c03_3d=0, c04_3d=0,
    up0_3d=8,
    c10_3d=8, c11_3d=4, c12_3d=0, c13_3d=0, c14_3d=0,
    up1_3d=0,
    c20_3d=0, c21_3d=0, c22_3d=0, c23_3d=0, c24_3d=0,
    up2_3d=0,
    c30_3d=0, c31_3d=0, c32_3d=0, c33_3d=0, c34_3d=0,
    init_fn=HyLFM_Net.InitName.xavier_uniform_,
    final_activation=None,
)
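# The WeightedL1_MS_SSIM criterion selected above pairs a voxel-wise L1 term
# with an MS-SSIM term. Below is a minimal sketch of such a combination,
# assuming crit_ms_ssim_weight simply scales a (1 - MS-SSIM) penalty; the
# class name and exact weighting are illustrative, not hylfm-net's own code
# (the crit_threshold/crit_weight kwargs suggest additional threshold-based
# voxel weighting, which this sketch omits).
import torch
from pytorch_msssim import MS_SSIM  # pip install pytorch-msssim


class L1PlusMSSSIM(torch.nn.Module):
    def __init__(self, ms_ssim_weight: float = 0.01, data_range: float = 1.0, win_size: int = 11, win_sigma: float = 1.5):
        super().__init__()
        self.ms_ssim_weight = ms_ssim_weight
        # data_range/win_size/win_sigma mirror the train() kwargs above
        self.ms_ssim = MS_SSIM(data_range=data_range, win_size=win_size, win_sigma=win_sigma, channel=1)

    def forward(self, pred: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        # MS-SSIM is a similarity in [0, 1]; 1 - MS-SSIM turns it into a loss term
        return torch.nn.functional.l1_loss(pred, tgt) + self.ms_ssim_weight * (1.0 - self.ms_ssim(pred, tgt))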
# preset: high-concentration beads, trained from scratch
train(
    dataset=DatasetChoice.beads_highc_b,
    batch_multiplier=2,
    batch_size=1,
    crit_apply_weight_above_threshold=False,
    crit_beta=1.0,
    crit_decay_weight_by=0.8,
    crit_decay_weight_every_unit=PeriodUnit.epoch,
    crit_decay_weight_every_value=1,
    crit_decay_weight_limit=1.0,
    crit_ms_ssim_weight=0.01,
    crit_threshold=0.5,
    crit_weight=0.001,
    criterion=CriterionChoice.WeightedSmoothL1,
    data_range=1.0,
    eval_batch_size=1,
    interpolation_order=2,
    lr_sched_factor=0.5,
    lr_sched_patience=10,
    lr_sched_thres=0.0001,
    lr_sched_thres_mode=LRSchedThresMode.abs,
    lr_scheduler=LRSchedulerChoice.ReduceLROnPlateau,
    max_epochs=10,
    model_weights=None,  # Path()
    opt_lr=3e-4,
    opt_momentum=0.0,
    opt_weight_decay=0.0,
    optimizer=OptimizerChoice.Adam,
    patience=5,
    score_metric=MetricChoice.MS_SSIM,
    seed=None,
    validate_every_unit=PeriodUnit.epoch,
    validate_every_value=1,
    win_sigma=1.5,
    win_size=11,
    # model
    nnum=19,
    z_out=51,
    kernel2d=3,
    c00_2d=976, c01_2d=976, c02_2d=0, c03_2d=0, c04_2d=0,
    up0_2d=488,
    c10_2d=488, c11_2d=0, c12_2d=0, c13_2d=0, c14_2d=0,
    up1_2d=244,
    c20_2d=244, c21_2d=0, c22_2d=0, c23_2d=0, c24_2d=0,
    up2_2d=0,
    c30_2d=0, c31_2d=0, c32_2d=0, c33_2d=0, c34_2d=0,
    last_kernel2d=1,
    cin_3d=7,
    kernel3d=3,
    c00_3d=7, c01_3d=0, c02_3d=0, c03_3d=0, c04_3d=0,
    up0_3d=7,
    c10_3d=7, c11_3d=7, c12_3d=0, c13_3d=0, c14_3d=0,
    up1_3d=0,
    c20_3d=0, c21_3d=0, c22_3d=0, c23_3d=0, c24_3d=0,
    up2_3d=0,
    c30_3d=0, c31_3d=0, c32_3d=0, c33_3d=0, c34_3d=0,
    init_fn=HyLFM_Net.InitName.xavier_uniform_,
    final_activation=None,
)
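# The crit_decay_weight_* kwargs describe a multiplicative decay schedule for
# the criterion's voxel weight: multiply by crit_decay_weight_by once per
# crit_decay_weight_every_value units (epochs in this preset, iterations in
# the others) until crit_decay_weight_limit is reached. A minimal sketch,
# assuming the weight starts above the limit and decays down towards it:
def decayed_crit_weight(
    initial: float, step: int, decay_by: float = 0.8, every: int = 1, limit: float = 1.0
) -> float:
    # `step` counts whole PeriodUnits (here: epochs); the weight never drops below `limit`
    return max(initial * decay_by ** (step // every), limit)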
# preset: static sliced heart (fish1, 2020-02), fine-tuned from an existing checkpoint
train(
    dataset=DatasetChoice.heart_2020_02_fish1_static_sliced,
    batch_multiplier=2,
    batch_size=5,
    crit_apply_weight_above_threshold=False,
    crit_beta=1.0,
    crit_decay_weight_by=0.97,
    crit_decay_weight_every_unit=PeriodUnit.iteration,
    crit_decay_weight_every_value=100,
    crit_decay_weight_limit=1.0,
    crit_ms_ssim_weight=0.01,
    crit_threshold=1.0,
    crit_weight=0.001,
    criterion=CriterionChoice.WeightedL1_MS_SSIM,
    data_range=1.0,
    eval_batch_size=5,
    interpolation_order=2,
    lr_sched_factor=0.5,
    lr_sched_patience=10,
    lr_sched_thres=0.0001,
    lr_sched_thres_mode=LRSchedThresMode.abs,
    lr_scheduler=LRSchedulerChoice.ReduceLROnPlateau,
    max_epochs=200,
    # model_weights=Path("/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/fake_dyn/from_scratch/21-01-12_21-05-48/train_01/run000/checkpoints/v1_checkpoint_37400_ms_ssim-scaled=0.8433429219506003.pth"),
    model_weights=Path(
        "/g/kreshuk/beuttenm/pycharm_tmp/repos/hylfm-net/logs/train/heart/fake_dyn/from_dyn_incl_dyn/21-01-13_10-02-44/train_01/run000/checkpoints/v1_checkpoint_9900_ms_ssim-scaled=0.8582265810533003.pth"
    ),
    opt_lr=5e-4,
    opt_momentum=0.0,
    opt_weight_decay=0.0,
    optimizer=OptimizerChoice.Adam,
    patience=10,
    score_metric=MetricChoice.MS_SSIM,
    seed=None,
    validate_every_unit=PeriodUnit.iteration,
    validate_every_value=400,
    win_sigma=1.5,
    win_size=11,
    # model
    nnum=19,
    z_out=49,
    kernel2d=3,
    c00_2d=488, c01_2d=488, c02_2d=0, c03_2d=0, c04_2d=0,
    up0_2d=244,
    c10_2d=244, c11_2d=0, c12_2d=0, c13_2d=0, c14_2d=0,
    up1_2d=0,
    c20_2d=0, c21_2d=0, c22_2d=0, c23_2d=0, c24_2d=0,
    up2_2d=0,
    c30_2d=0, c31_2d=0, c32_2d=0, c33_2d=0, c34_2d=0,
    last_kernel2d=1,
    cin_3d=7,
    kernel3d=3,
    c00_3d=7, c01_3d=0, c02_3d=0, c03_3d=0, c04_3d=0,
    up0_3d=7,
    c10_3d=7, c11_3d=7, c12_3d=0, c13_3d=0, c14_3d=0,
    up1_3d=0,
    c20_3d=0, c21_3d=0, c22_3d=0, c23_3d=0, c24_3d=0,
    up2_3d=0,
    c30_3d=0, c31_3d=0, c32_3d=0, c33_3d=0, c34_3d=0,
    init_fn=HyLFM_Net.InitName.xavier_uniform_,
    final_activation=None,
)
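# All three presets pick Adam with ReduceLROnPlateau. A sketch of how the
# opt_* / lr_sched_* kwargs plausibly map onto torch.optim; mode="max" is an
# assumption, since the score metric (MS-SSIM) is higher-is-better, and
# opt_momentum=0.0 is ignored because Adam takes no momentum argument (it
# would only apply to a momentum-based optimizer such as SGD):
import torch


def make_optimizer_and_scheduler(net: torch.nn.Module):
    opt = torch.optim.Adam(net.parameters(), lr=5e-4, weight_decay=0.0)
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
        opt, mode="max", factor=0.5, patience=10, threshold=1e-4, threshold_mode="abs"
    )
    return opt, sched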
def train_model_like(
    model_kwargs_from_checkpoint: Path,
    batch_multiplier: Optional[int] = typer.Option(None, "--batch_multiplier"),
    batch_size: Optional[int] = typer.Option(None, "--batch_size"),
    crit_decay_weight_by: Optional[float] = typer.Option(None, "--crit_decay_weight_by"),
    crit_decay_weight_every_unit: Optional[PeriodUnit] = typer.Option(None, "--crit_decay_weight_every_unit"),
    crit_decay_weight_every_value: Optional[int] = typer.Option(None, "--crit_decay_weight_every_value"),
    crit_threshold: Optional[float] = typer.Option(None, "--crit_threshold"),
    crit_weight: Optional[float] = typer.Option(None, "--crit_weight"),
    criterion: Optional[CriterionChoice] = typer.Option(None, "--criterion"),
    dataset: Optional[DatasetChoice] = typer.Option(None, "--dataset"),
    eval_batch_size: Optional[int] = typer.Option(None, "--eval_batch_size"),
    interpolation_order: Optional[int] = typer.Option(None, "--interpolation_order"),
    lr_sched_factor: Optional[float] = typer.Option(None, "--lr_sched_factor"),
    lr_sched_patience: Optional[int] = typer.Option(None, "--lr_sched_patience"),
    lr_sched_thres: Optional[float] = typer.Option(None, "--lr_sched_thres"),
    lr_sched_thres_mode: Optional[LRSchedThresMode] = typer.Option(None, "--lr_sched_thres_mode"),
    lr_scheduler: Optional[LRSchedulerChoice] = typer.Option(None, "--lr_scheduler"),
    max_epochs: Optional[int] = typer.Option(None, "--max_epochs"),
    model_weights: Optional[Path] = typer.Option(None, "--model_weights"),
    opt_lr: Optional[float] = typer.Option(None, "--opt_lr"),
    opt_momentum: Optional[float] = typer.Option(None, "--opt_momentum"),
    opt_weight_decay: Optional[float] = typer.Option(None, "--opt_weight_decay"),
    optimizer: Optional[OptimizerChoice] = typer.Option(None, "--optimizer"),
    patience: Optional[int] = typer.Option(None, "--patience"),
    save_after_validation_iterations: Optional[List[int]] = typer.Option(None, "--save_after_validation_iterations"),
    score_metric: MetricChoice = typer.Option(MetricChoice.MS_SSIM, "--score_metric"),
    seed: Optional[int] = typer.Option(None, "--seed"),
    validate_every_unit: Optional[PeriodUnit] = typer.Option(None, "--validate_every_unit"),
    validate_every_value: Optional[int] = typer.Option(None, "--validate_every_value"),
    win_sigma: Optional[float] = typer.Option(None, "--win_sigma"),
    win_size: Optional[int] = typer.Option(None, "--win_size"),
    z_out: Optional[int] = typer.Option(None, "--z_out"),
    zero_max_patience: Optional[int] = typer.Option(None, "--zero_max_patience"),
):
    """Train a model with the configuration of an existing checkpoint, overriding only the options given explicitly."""
    reference_checkpoint = Checkpoint.load(model_kwargs_from_checkpoint)
    reference_config = reference_checkpoint.config
    config = reference_config.as_dict(for_logging=False)
    config.pop("hylfm_version")
    config.update(config.pop("model"))  # flatten model kwargs into config

    # collect only the CLI options that were actually set
    changes = {
        k: v
        for k, v in {
            "batch_multiplier": batch_multiplier,
            "batch_size": batch_size,
            "crit_decay_weight_by": crit_decay_weight_by,
            "crit_decay_weight_every_unit": crit_decay_weight_every_unit,
            "crit_decay_weight_every_value": crit_decay_weight_every_value,
            "crit_threshold": crit_threshold,
            "crit_weight": crit_weight,
            "criterion": criterion,
            "dataset": dataset,
            "eval_batch_size": eval_batch_size,
            "interpolation_order": interpolation_order,
            "lr_sched_factor": lr_sched_factor,
            "lr_sched_patience": lr_sched_patience,
            "lr_sched_thres": lr_sched_thres,
            "lr_sched_thres_mode": lr_sched_thres_mode,
            "lr_scheduler": lr_scheduler,
            "max_epochs": max_epochs,
            "model_weights": model_weights,
            "opt_lr": opt_lr,
            "opt_momentum": opt_momentum,
            "opt_weight_decay": opt_weight_decay,
            "optimizer": optimizer,
            "patience": patience,
            "save_after_validation_iterations": save_after_validation_iterations,
            "score_metric": score_metric,
            "seed": seed,
            "validate_every_unit": validate_every_unit,
            "validate_every_value": validate_every_value,
            "win_sigma": win_sigma,
            "win_size": win_size,
            "z_out": z_out,
            "zero_max_patience": zero_max_patience,
        }.items()
        if v is not None
    }

    note = f"train like {model_kwargs_from_checkpoint.resolve()} {'with changes:' if changes else ''} " + " ".join(
        [f"{k}: {v}" for k, v in changes.items()]
    )
    logger.info(note)
    if len(note) > 500:  # keep the stored note short; the full note was already logged
        note = note[:499] + "…"

    config.update(changes)
    train(**config, note=note)