def test_Timer_callback():
    """Smoke-test a Runner fitted for two epochs with only the Timer callback attached."""
    runner_kwargs = dict(
        model=TestModel,
        optimizer=TestOptimizer,
        criterion=TestCriterion,
        metrics=TestMetric,
        callbacks=pt_clb.Timer(),
    )
    Runner(**runner_kwargs).fit(TestLoader, epochs=2)
def test_rank_zero_only():
    """Check that the ``rank_zero_only`` decorator disables some callbacks on non-zero ranks.

    Both wrapping a callback instance and wrapping a callback class are exercised.
    The test has to mutate the process-wide ``RANK`` environment variable, so its
    previous value is restored afterwards to avoid leaking state into other tests.
    """
    previous_rank = os.environ.get("RANK")
    try:
        os.environ["RANK"] = "0"
        # check that wrapping instance works
        timer = pt_clb.rank_zero_only(pt_clb.Timer())
        assert hasattr(timer, "timer")

        os.environ["RANK"] = "1"
        # check that wrapping class also works
        timer = pt_clb.rank_zero_only(pt_clb.Timer)()
        assert not hasattr(timer, "timer")
    finally:
        if previous_rank is None:
            os.environ.pop("RANK", None)
        else:
            os.environ["RANK"] = previous_rank
criterion=TEST_CRITERION, metrics=TEST_METRIC, gradient_clip_val=1.0 ) runner.fit(TEST_LOADER, epochs=2) # We only test that callbacks don't crash NOT that they do what they should do TMP_PATH = "/tmp/pt_tools2/" os.makedirs(TMP_PATH, exist_ok=True) @pytest.mark.parametrize( "callback", [ pt_clb.Timer(), pt_clb.ReduceLROnPlateau(), pt_clb.CheckpointSaver(TMP_PATH, save_name="model.chpn"), pt_clb.CheckpointSaver( TMP_PATH, save_name="model.chpn", monitor=TEST_METRIC.name, mode="max" ), pt_clb.TensorBoard(log_dir=TMP_PATH), pt_clb.TensorBoardWithCM(log_dir=TMP_PATH), pt_clb.ConsoleLogger(), pt_clb.FileLogger(TMP_PATH), pt_clb.Mixup(0.2, NUM_CLASSES), pt_clb.Cutmix(1.0, NUM_CLASSES), pt_clb.ScheduledDropout(), ], ) def test_callback(callback):
def main():
    """Train a binary segmentation model configured from command-line flags.

    Parses flags, fixes seeds, and dumps the config to ``<outdir>/config.yaml``.
    It then builds the dataloaders, model and optimizer, optionally resumes
    weights, wraps model/optimizer with apex AMP (O1) and fits.

    When ``FLAGS.decoder_warmup_epochs > 0``, the encoder is frozen for a warm-up
    phase first. Afterwards the optimizer and AMP are re-initialised and training
    continues up to the scheduler's total number of epochs.
    """
    FLAGS = parse_args()
    print(FLAGS)
    pt.utils.misc.set_random_seed(42)  # fix all seeds

    ## dump config; `with` guarantees the file is flushed and closed
    os.makedirs(FLAGS.outdir, exist_ok=True)
    with open(FLAGS.outdir + "/config.yaml", "w") as config_file:
        yaml.dump(vars(FLAGS), config_file)

    ## get dataloaders
    if FLAGS.train_tta:
        FLAGS.bs //= 4  # account for later augmentations to avoid OOM
    train_dtld, val_dtld = get_dataloaders(
        FLAGS.datasets, FLAGS.augmentation, FLAGS.bs, FLAGS.size, FLAGS.val_size, FLAGS.buildings_only
    )

    ## get model and optimizer
    model = MODEL_FROM_NAME[FLAGS.segm_arch](FLAGS.arch, **FLAGS.model_params).cuda()
    if FLAGS.train_tta:
        # idea from https://arxiv.org/pdf/2002.09024.pdf paper
        model = pt.tta_wrapper.TTA(model, segm=True, h_flip=True, rotation=[90], merge="max")
        # expose encoder / decoder on the wrapper so the freezing logic below keeps working
        model.encoder = model.model.encoder
        model.decoder = model.model.decoder

    optimizer = optimizer_from_name(FLAGS.optim)(
        model.parameters(),
        lr=FLAGS.lr,
        weight_decay=FLAGS.weight_decay,
        # **FLAGS.optim_params TODO: add additional optim params if needed
    )
    if FLAGS.lookahead:
        optimizer = pt.optim.Lookahead(optimizer)

    if FLAGS.resume:
        checkpoint = torch.load(FLAGS.resume, map_location=lambda storage, loc: storage.cuda())
        model.load_state_dict(checkpoint["state_dict"], strict=False)

    num_params = pt.utils.misc.count_parameters(model)[0]
    print(f"Number of parameters: {num_params / 1e6:.02f}M")

    ## train on fp16 by default
    model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1", verbosity=0, loss_scale=1024)

    ## get loss. fixed for now.
    bce_loss = TargetWrapper(pt.losses.CrossEntropyLoss(mode="binary").cuda(), "mask")
    bce_loss.name = "BCE"
    loss = criterion_from_list(FLAGS.criterion).cuda()
    print("Loss for this run is: ", loss)

    ## get runner
    scheduler = pt.fit_wrapper.callbacks.PhasesScheduler(FLAGS.phases)
    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion=loss,
        callbacks=[
            pt_clb.Timer(),
            pt_clb.ConsoleLogger(),
            pt_clb.FileLogger(FLAGS.outdir),
            pt_clb.SegmCutmix(1, 1) if FLAGS.cutmix else NoClb(),
            pt_clb.CheckpointSaver(FLAGS.outdir, save_name="model.chpn"),
            scheduler,
            PredictViewer(FLAGS.outdir, num_images=8),
            ScheduledDropout(FLAGS.dropout, FLAGS.dropout_epochs) if FLAGS.dropout else NoClb(),
        ],
        metrics=[
            bce_loss,
            TargetWrapper(pt.metrics.JaccardScore(mode="binary").cuda(), "mask"),
            TargetWrapper(ThrJaccardScore(thr=0.5), "mask"),
            TargetWrapper(BalancedAccuracy2(balanced=False), "mask"),
        ],
    )

    # In short-epoch (debug) mode only 50 train / 50 val steps are run per epoch.
    num_steps = 50 if FLAGS.short_epoch else None

    if FLAGS.decoder_warmup_epochs > 0:
        ## freeze encoder and train only the rest for a few epochs
        for p in model.encoder.parameters():
            p.requires_grad = False
        runner.fit(
            train_dtld,
            val_loader=val_dtld,
            epochs=FLAGS.decoder_warmup_epochs,
            steps_per_epoch=num_steps,
            val_steps=num_steps,
        )
        ## unfreeze all
        for p in model.parameters():
            p.requires_grad = True
        # need to init again to avoid nan's in loss
        optimizer = optimizer_from_name(FLAGS.optim)(
            model.parameters(),
            lr=FLAGS.lr,
            weight_decay=FLAGS.weight_decay,
            # **FLAGS.optim_params TODO: add additional optim params if needed
        )
        if FLAGS.lookahead:
            # re-wrap so --lookahead also applies to the main training phase
            optimizer = pt.optim.Lookahead(optimizer)
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1", verbosity=0, loss_scale=2048)
        runner.state.model = model
        runner.state.optimizer = optimizer

    runner.fit(
        train_dtld,
        val_loader=val_dtld,
        start_epoch=FLAGS.decoder_warmup_epochs,
        epochs=scheduler.tot_epochs,
        steps_per_epoch=num_steps,
        val_steps=num_steps,
    )
def main():
    """Train an embedding model, then evaluate the best checkpoint on the validation split.

    Reads hyper-parameters from the command line and logs to stdout and
    ``<outdir>/logs.txt``. It saves the config, then optionally trains for
    ``head_warmup_epochs`` with the model frozen; only the loss's parameters
    are updated then. Next it trains the full model and reloads
    ``<outdir>/model.chpn``. Finally it writes the validation metrics plus
    the hyper-parameters to TensorBoard.
    """
    # Get config for this run
    hparams = parse_args()

    # Setup logger: the same format for stdout and for the log file in the output dir
    log_format = "{time:[MM-DD HH:mm]} - {message}"
    config = {
        "handlers": [
            {"sink": sys.stdout, "format": log_format},
            {"sink": f"{hparams.outdir}/logs.txt", "format": log_format},
        ],
    }
    logger.configure(**config)
    logger.info(f"Parameters used for training: {hparams}")

    # Fix seeds for reproducibility
    pt.utils.misc.set_random_seed(hparams.seed)

    # Save config; `with` guarantees the file is flushed and closed
    os.makedirs(hparams.outdir, exist_ok=True)
    with open(hparams.outdir + "/config.yaml", "w") as config_file:
        yaml.dump(vars(hparams), config_file)

    # Get model
    model = Model(
        arch=hparams.arch,
        model_params=hparams.model_params,
        embedding_size=hparams.embedding_size,
        pooling=hparams.pooling,
    ).cuda()

    # Get loss. NOTE: hard-coded to "cross_entropy"; hparams.criterion / criterion_params
    # are not used to build it (previously:
    # LOSS_FROM_NAME[hparams.criterion](in_features=hparams.embedding_size, **hparams.criterion_params)).
    loss = LOSS_FROM_NAME["cross_entropy"].cuda()
    logger.info(f"Loss for this run is: {loss}")

    if hparams.resume:
        checkpoint = torch.load(hparams.resume, map_location=lambda storage, loc: storage.cuda())
        model.load_state_dict(checkpoint["state_dict"], strict=True)
        loss.load_state_dict(checkpoint["loss"], strict=True)

    if hparams.freeze_bn:
        freeze_batch_norm(model)

    # Get optimizer: loss parameters are optimized together with the model's.
    # lr=0 here; presumably the PhasesScheduler callback sets the real value.
    optim_params = list(loss.parameters()) + list(model.parameters())
    optimizer = optimizer_from_name(hparams.optim)(
        optim_params, lr=0, weight_decay=hparams.weight_decay, amsgrad=True
    )

    num_params = pt.utils.misc.count_parameters(model)[0]
    logger.info(f"Model size: {num_params / 1e6:.02f}M")

    # Scheduler is an advanced way of planning experiment
    sheduler = pt.fit_wrapper.callbacks.PhasesScheduler(hparams.phases)

    # Save logs
    TB_callback = pt_clb.TensorBoard(hparams.outdir, log_every=20)

    # Get dataloaders
    train_loader, val_loader, val_indexes = get_dataloaders(
        root=hparams.root,
        augmentation=hparams.augmentation,
        size=hparams.size,
        val_size=hparams.val_size,
        batch_size=hparams.batch_size,
        workers=hparams.workers,
    )

    # Load validation query / gallery split and re-sort it according to indexes from sampler.
    # Plain `bool` is used because the `np.bool` alias was removed in NumPy 1.24.
    df_val = pd.read_csv(os.path.join(hparams.root, "train_val.csv"))
    df_val = df_val[~df_val["is_train"].astype(bool)]
    val_is_query = df_val.is_query.values[val_indexes].astype(bool)

    logger.info("Start training")
    # Init runner
    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion=loss,
        callbacks=[
            ContestMetricsCallback(is_query=val_is_query[:1280] if hparams.debug else val_is_query),
            pt_clb.Timer(),
            pt_clb.ConsoleLogger(),
            pt_clb.FileLogger(),
            TB_callback,
            CheckpointSaver(hparams.outdir, save_name="model.chpn", monitor="target", mode="max"),
            CheckpointSaver(hparams.outdir, save_name="model_mapr.chpn", monitor="mAP@R", mode="max"),
            CheckpointSaver(hparams.outdir, save_name="model_loss.chpn"),
            sheduler,
            # EMA must go after other checkpoints
            pt_clb.ModelEma(model, hparams.ema_decay) if hparams.ema_decay else pt_clb.Callback(),
        ],
        use_fp16=hparams.use_fp16,  # use mixed precision by default
    )

    # In debug mode only 20 steps are run per epoch / evaluation.
    debug_steps = 20 if hparams.debug else None

    # NOTE: val_loader is not passed to fit(); validation runs once via runner.evaluate below.
    if hparams.head_warmup_epochs > 0:
        # Freeze model: only the loss parameters are updated during warm-up
        for p in model.parameters():
            p.requires_grad = False
        runner.fit(
            train_loader,
            epochs=hparams.head_warmup_epochs,
            steps_per_epoch=debug_steps,
        )

        # Unfreeze model
        for p in model.parameters():
            p.requires_grad = True
        if hparams.freeze_bn:
            freeze_batch_norm(model)

        # Re-init to avoid nan's in loss
        optim_params = list(loss.parameters()) + list(model.parameters())
        optimizer = optimizer_from_name(hparams.optim)(
            optim_params, lr=0, weight_decay=hparams.weight_decay, amsgrad=True
        )
        runner.state.model = model
        runner.state.optimizer = optimizer
        runner.state.criterion = loss

    # Train
    runner.fit(
        train_loader,
        start_epoch=hparams.head_warmup_epochs,
        epochs=sheduler.tot_epochs,
        steps_per_epoch=debug_steps,
    )

    logger.info("Loading best model")
    checkpoint = torch.load(os.path.join(hparams.outdir, "model.chpn"))
    model.load_state_dict(checkpoint["state_dict"], strict=True)

    # Evaluate
    _, [acc1, map10, target, mapR] = runner.evaluate(val_loader, steps=debug_steps)
    logger.info(
        f"Val: Acc@1 {acc1:0.5f}, mAP@10 {map10:0.5f}, Target {target:0.5f}, mAP@R {mapR:0.5f}"
    )

    # Save params used for training and final metrics into separate TensorBoard file.
    # Each key is paired with the metric of the same name.
    metric_dict = {
        "hparam/Acc@1": acc1,
        "hparam/mAP@10": map10,
        "hparam/mAP@R": mapR,
        "hparam/Target": target,
    }

    # Convert all lists / dicts to str to avoid TB error
    hparams.phases = str(hparams.phases)
    hparams.model_params = str(hparams.model_params)
    hparams.criterion_params = str(hparams.criterion_params)
    with pt.utils.tensorboard.CorrectedSummaryWriter(hparams.outdir) as writer:
        writer.add_hparams(hparam_dict=vars(hparams), metric_dict=metric_dict)
def main():
    """Train a model and keep the best checkpoint for each monitored metric.

    Reads hyper-parameters from the command line and logs to stdout and
    ``<outdir>/logs.txt``. It saves the config, builds the model, loss,
    per-image and feature-based metrics, and registers one ``CheckpointSaver``
    per monitored metric. Then it trains on the train dataset and validates
    on the val dataset.
    """
    # Get config for this run
    hparams = parse_args()

    # Setup logger: the same format for stdout and for the log file in the output dir
    log_format = "{time:[MM-DD HH:mm]} - {message}"
    config = {
        "handlers": [
            {"sink": sys.stdout, "format": log_format},
            {"sink": f"{hparams.outdir}/logs.txt", "format": log_format},
        ],
    }
    logger.configure(**config)
    # Log every parameter; vars() renders them as a dict
    logger.info(f"Parameters used for training: {vars(hparams)}")

    # Fix all seeds for reproducibility
    pt.utils.misc.set_random_seed(hparams.seed)

    # Save config; `with` guarantees the file is flushed and closed
    os.makedirs(hparams.outdir, exist_ok=True)
    with open(hparams.outdir + "/config.yaml", "w") as config_file:
        yaml.dump(vars(hparams), config_file)

    # Get models and optimizers
    model = MODEL_FROM_NAME[hparams.model](**hparams.model_params).cuda()
    logger.info(f"Model size: {pt.utils.misc.count_parameters(model)[0] / 1e6:.02f}M")
    # Get LR from phases later (no lr passed here)
    optimizer = torch.optim.Adam(model.parameters(), weight_decay=hparams.weight_decay, amsgrad=True)

    # Get loss
    loss = criterion_from_list(hparams.criterion).cuda()

    # Init per-image metrics and add names
    metrics = metrics_from_list(hparams.metrics, reduction="mean")
    logger.info(f"Metrics: {[m.name for m in metrics]}")

    # Init feature metrics. Each one is a shallow copy so that renaming it (suffix = feature
    # extractor name) does not rename the shared object stored in METRIC_FROM_NAME.
    feature_extractor = "vgg16"
    feature_metrics = []
    for name in hparams.feature_metrics:
        metric = copy.copy(METRIC_FROM_NAME[name])
        metric.name = f"{name}_{feature_extractor}"
        feature_metrics.append(metric)

    # Scheduler is an advanced way of planning experiment
    sheduler = pt_clb.PhasesScheduler(hparams.phases)

    # One checkpoint file per monitored metric: (metric name, optimisation direction).
    # NOTE(review): 'mdsi' is monitored with mode='max' and 'is_vgg16' / 'is_metric_vgg16'
    # with mode='min'; confirm these directions match how the metrics are defined.
    save_name = "model_{monitor}.chpn"
    checkpoint_monitors = [
        ("loss", "min"),
        ("psnr", "max"),
        ("ssim", "max"),
        ("ms-ssim", "max"),
        ("gmsd", "min"),
        ("ms-gmsd", "min"),
        ("ms-gmsdc", "min"),
        ("fsim", "max"),
        ("fsimc", "max"),
        ("vsi", "max"),
        ("mdsi", "max"),
        ("vifp", "max"),
        ("content_vgg16_ap", "min"),
        ("style_vgg16", "min"),
        ("lpips", "min"),
        ("dists", "min"),
        ("brisque", "min"),
        ("is_metric_vgg16", "min"),
        ("is_vgg16", "min"),
        ("kid_vgg16", "min"),
        ("fid_vgg16", "min"),
        ("msid_vgg16", "min"),
    ]

    callbacks = [
        pt_clb.Timer(),
        clb.FeatureLoaderMetrics(metrics=feature_metrics, feature_extractor=feature_extractor),
        pt_clb.BatchMetrics(metrics=metrics),
        clb.ConsoleLogger(metrics=["ssim", "psnr"]),
        clb.TensorBoard(hparams.outdir, log_every=40, num_images=2),
    ]
    callbacks.extend(
        clb.CheckpointSaver(hparams.outdir, save_name=save_name, monitor=monitor, mode=mode, verbose=False)
        for monitor, mode in checkpoint_monitors
    )
    callbacks.append(sheduler)

    # Init train loop
    runner = pt.fit_wrapper.Runner(
        model=model,
        optimizer=optimizer,
        criterion=loss,
        callbacks=callbacks,
    )

    # Get dataloaders
    transform = get_aug(
        aug_type=hparams.aug_type, task=hparams.task, dataset=hparams.train_dataset, size=hparams.size
    )
    train_loader = get_dataloader(
        dataset=hparams.train_dataset, train=True, transform=transform, batch_size=hparams.batch_size
    )
    transform = get_aug(aug_type="val", task=hparams.task, dataset=hparams.val_dataset, size=hparams.size)
    val_loader = get_dataloader(
        dataset=hparams.val_dataset, train=False, transform=transform, batch_size=hparams.batch_size
    )

    # Train; in debug mode only 2 train / val steps per epoch
    debug_steps = 2 if hparams.debug else None
    runner.fit(
        train_loader,
        epochs=sheduler.tot_epochs,
        val_loader=val_loader,
        steps_per_epoch=debug_steps,
        val_steps=debug_steps,
    )
    logger.info("Finished training!")
def main():
    """Train a detection model (optionally with distributed data parallel) using DALI loaders.

    The master process logs to stdout and ``<outdir>/logs.txt``. It also writes
    the config, the git commit hash and the git diff into ``outdir``. Other
    processes log nothing.

    Returns:
        ``(None, (42, 42))`` as a placeholder when ``FLAGS.evaluate`` is set (nothing is
        trained or evaluated); otherwise ``(runner.state.val_loss.avg, (0, 0))`` after
        training, where ``(0, 0)`` is a placeholder for metrics.
    """
    ## get config for this run
    FLAGS = parse_args()
    os.makedirs(FLAGS.outdir, exist_ok=True)
    log_format = "{time:[MM-DD HH:mm:ss]} - {message}"
    config = {
        "handlers": [
            {"sink": sys.stdout, "format": log_format},
            {"sink": f"{FLAGS.outdir}/logs.txt", "format": log_format},
        ],
    }
    if FLAGS.is_master:
        logger.configure(**config)
        ## dump config, commit hash and diff for reproducibility
        with open(FLAGS.outdir + "/config.yaml", "w") as fp:
            yaml.dump(vars(FLAGS), fp)
        kwargs = {"universal_newlines": True, "stdout": subprocess.PIPE}
        with open(FLAGS.outdir + "/commit_hash.txt", "w") as fp:
            fp.write(subprocess.run(["git", "rev-parse", "--short", "HEAD"], **kwargs).stdout)
        with open(FLAGS.outdir + "/diff.txt", "w") as fp:
            fp.write(subprocess.run(["git", "diff"], **kwargs).stdout)
    else:
        # non-master processes get no log handlers
        logger.configure(handlers=[])
    logger.info(FLAGS)

    ## makes it slightly faster
    cudnn.benchmark = True
    if FLAGS.deterministic:
        pt.utils.misc.set_random_seed(42)  # fix all seeds

    ## setup distributed
    if FLAGS.distributed:
        logger.info("Distributed initializing process group")
        torch.cuda.set_device(FLAGS.local_rank)
        dist.init_process_group(backend="nccl", init_method="env://", world_size=FLAGS.world_size)

    ## get dataloaders
    train_loader = DaliLoader(True, FLAGS.batch_size, FLAGS.workers, FLAGS.size)
    val_loader = DaliLoader(False, FLAGS.batch_size, FLAGS.workers, FLAGS.size)

    ## get model
    logger.info(f"=> Creating model '{FLAGS.arch}'")
    model = det_models.__dict__[FLAGS.arch](**FLAGS.model_params)
    if FLAGS.weight_standardization:
        model = pt.modules.weight_standartization.conv_to_ws_conv(model)
    model = model.cuda()

    ## get optimizer
    # want to filter BN from weight decay by default. It never hurts
    optim_params = pt.utils.misc.filter_bn_from_wd(model)
    # start with 0 lr. Scheduler will change this later
    optimizer = optimizer_from_name(FLAGS.optim)(
        optim_params, lr=0, weight_decay=FLAGS.weight_decay, **FLAGS.optim_params
    )
    if FLAGS.lookahead:
        optimizer = pt.optim.Lookahead(optimizer, la_alpha=0.5)

    ## load weights from previous run if given
    if FLAGS.resume:
        checkpoint = torch.load(FLAGS.resume, map_location=lambda s, loc: s.cuda())  # map for multi-gpu
        model.load_state_dict(checkpoint["state_dict"])  # strict=False
        FLAGS.start_epoch = checkpoint["epoch"]
        try:
            optimizer.load_state_dict(checkpoint["optimizer"])
        except Exception:
            # Best effort: may raise if another optimizer was used or there is no optimizer
            # in the state dict. `Exception` (not a bare except) lets Ctrl-C / SystemExit through.
            logger.info("Failed to load state dict into optimizer")

    # Important to create EMA Callback after cuda() and AMP but before DDP wrapper
    ema_clb = pt_clb.ModelEma(model, FLAGS.ema_decay) if FLAGS.ema_decay else NoClbk()
    if FLAGS.distributed:
        model = DDP(model, delay_allreduce=True)

    ## define loss function (criterion)
    anchors = pt.utils.box.generate_anchors_boxes(FLAGS.size)[0]
    # script loss to lower memory consumption and make it faster
    # as of 1.5 it does run but loss doesn't decrease for some reason
    # FIXME: uncomment after 1.6
    criterion = torch.jit.script(DetectionLoss(anchors).cuda())
    # criterion = DetectionLoss(anchors).cuda()

    ## load COCO (needed for evaluation)
    val_coco_api = COCO("data/annotations/instances_val2017.json")

    model_saver = (
        pt_clb.CheckpointSaver(FLAGS.outdir, save_name="model.chpn") if FLAGS.is_master else NoClbk()
    )
    sheduler = pt.fit_wrapper.callbacks.PhasesScheduler(FLAGS.phases)
    # common callbacks
    callbacks = [
        pt_clb.StateReduce(),  # MUST go first
        sheduler,
        pt_clb.Mixup(FLAGS.mixup, 1000) if FLAGS.mixup else NoClbk(),
        pt_clb.Cutmix(FLAGS.cutmix, 1000) if FLAGS.cutmix else NoClbk(),
        model_saver,  # need to have CheckpointSaver before EMA so moving it here
        ema_clb,  # ModelEMA MUST go after checkpoint saver to work, otherwise it would save main model instead of EMA
        CocoEvalClbTB(FLAGS.outdir, val_coco_api, anchors),
    ]
    if FLAGS.is_master:
        # callbacks for master process only
        callbacks.extend(
            [
                pt_clb.Timer(),
                pt_clb.ConsoleLogger(),
                pt_clb.FileLogger(FLAGS.outdir, logger=logger),
            ]
        )

    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion,
        # metrics=[pt.metrics.Accuracy(), pt.metrics.Accuracy(5)],
        callbacks=callbacks,
        use_fp16=FLAGS.opt_level != "O0",
    )

    if FLAGS.evaluate:
        # NOTE(review): evaluation mode returns a placeholder instead of calling
        # runner.evaluate(val_loader); confirm this stub is still intended.
        return None, (42, 42)

    runner.fit(
        train_loader,
        steps_per_epoch=10 if FLAGS.short_epoch else None,
        val_loader=val_loader,
        # val_steps=20 if FLAGS.short_epoch else None,
        epochs=sheduler.tot_epochs,
        # start_epoch=FLAGS.start_epoch, # TODO: maybe want to continue from epoch
    )
    # TODO: maybe return best loss? (0, 0) stands in for [m.avg for m in runner.state.val_metrics]
    return runner.state.val_loss.avg, (0, 0)
def main():
    """Train a binary segmentation model with an optional encoder-frozen warm-up.

    Parses hyper-parameters and logs to stdout and ``<outdir>/logs.txt``. It
    saves the config, git commit hash and git diff into ``outdir``. It builds
    the model (optionally converting convs to weight-standardized convs),
    resumes weights if requested, wraps model/optimizer with apex AMP and
    trains.

    With ``decoder_warmup_epochs > 0`` the encoder is frozen first. Then the
    optimizer and AMP are re-initialised and the whole network is trained.
    """
    # Get config for this run. Must happen first: the logger setup below reads hparams.outdir.
    hparams = parse_args()

    # Setup logger
    log_format = "{time:[MM-DD HH:mm:ss]} - {message}"
    config = {
        "handlers": [
            {"sink": sys.stdout, "format": log_format},
            {"sink": f"{hparams.outdir}/logs.txt", "format": log_format},
        ],
    }
    logger.configure(**config)
    logger.info(f"Parameters used for training: {hparams}")

    # Fix seeds for reproducibility
    pt.utils.misc.set_random_seed(hparams.seed)

    ## Save config and Git diff (don't know how to do it without subprocess)
    os.makedirs(hparams.outdir, exist_ok=True)
    with open(hparams.outdir + "/config.yaml", "w") as f:
        yaml.dump(vars(hparams), f)
    kwargs = {"universal_newlines": True, "stdout": subprocess.PIPE}
    with open(hparams.outdir + "/commit_hash.txt", "w") as f:
        f.write(subprocess.run(["git", "rev-parse", "--short", "HEAD"], **kwargs).stdout)
    with open(hparams.outdir + "/diff.txt", "w") as f:
        f.write(subprocess.run(["git", "diff"], **kwargs).stdout)

    ## Get dataloaders
    train_loader, val_loader = get_dataloaders(
        root=hparams.root,
        augmentation=hparams.augmentation,
        fold=hparams.fold,
        pos_weight=hparams.pos_weight,
        batch_size=hparams.batch_size,
        size=hparams.size,
        val_size=hparams.val_size,
        workers=hparams.workers,
    )

    # Get model
    model = MODEL_FROM_NAME[hparams.segm_arch](hparams.backbone, **hparams.model_params).cuda()

    # Convert all Conv2D -> WS_Conv2d if needed. This happens BEFORE the optimizer is built so
    # the optimizer holds the parameters of the model that is actually trained, in case the
    # conversion returns new modules (TODO confirm what conv_to_ws_conv does internally).
    if hparams.ws:
        model = pt.modules.weight_standartization.conv_to_ws_conv(model).cuda()

    # Get optimizer; LR comes from the phases scheduler later
    optimizer = optimizer_from_name(hparams.optim)(
        model.parameters(),
        weight_decay=hparams.weight_decay,
    )

    # Load weights if needed
    if hparams.resume:
        checkpoint = torch.load(hparams.resume, map_location=lambda storage, loc: storage.cuda())
        model.load_state_dict(checkpoint["state_dict"], strict=False)

    num_params = pt.utils.misc.count_parameters(model)[0]
    logger.info(f"Model size: {num_params / 1e6:.02f}M")

    ## Use AMP
    model, optimizer = apex.amp.initialize(
        model, optimizer, opt_level=hparams.opt_level, verbosity=0, loss_scale=1024
    )

    # Get loss
    loss = criterion_from_list(hparams.criterion).cuda()
    logger.info(f"Loss for this run is: {loss}")
    bce_loss = pt.losses.CrossEntropyLoss(mode="binary").cuda()  # Used as a metric
    bce_loss.name = "BCE"

    # Scheduler is an advanced way of planning experiment
    sheduler = pt.fit_wrapper.callbacks.PhasesScheduler(hparams.phases)

    # Init runner
    runner = pt.fit_wrapper.Runner(
        model,
        optimizer,
        criterion=loss,
        callbacks=[
            pt_clb.Timer(),
            pt_clb.ConsoleLogger(),
            pt_clb.FileLogger(hparams.outdir, logger=logger),
            pt_clb.CheckpointSaver(hparams.outdir, save_name="model.chpn"),
            sheduler,
            PredictViewer(hparams.outdir, num_images=4),
        ],
        metrics=[
            bce_loss,
            pt.metrics.JaccardScore(mode="binary").cuda(),
            # ThrJaccardScore(thr=0.5),
        ],
    )

    # In debug mode only 10 train / val steps are run per epoch
    debug_steps = 10 if hparams.debug else None

    if hparams.decoder_warmup_epochs > 0:
        # Freeze encoder
        for p in model.encoder.parameters():
            p.requires_grad = False
        runner.fit(
            train_loader,
            val_loader=val_loader,
            epochs=hparams.decoder_warmup_epochs,
            steps_per_epoch=debug_steps,
            val_steps=debug_steps,
        )

        # Unfreeze all
        for p in model.parameters():
            p.requires_grad = True

        # Reinit again to avoid NaN's in loss
        optimizer = optimizer_from_name(hparams.optim)(
            model.parameters(),
            weight_decay=hparams.weight_decay,
        )
        model, optimizer = apex.amp.initialize(
            model, optimizer, opt_level=hparams.opt_level, verbosity=0, loss_scale=2048
        )
        runner.state.model = model
        runner.state.optimizer = optimizer

    # Train both encoder and decoder
    runner.fit(
        train_loader,
        val_loader=val_loader,
        start_epoch=hparams.decoder_warmup_epochs,
        epochs=sheduler.tot_epochs,
        steps_per_epoch=debug_steps,
        val_steps=debug_steps,
    )