def validate(loader_valid, model, writer, losses_tr=None, optimizer=None, epoch=0):
    """Run one validation pass over *loader_valid* and compute count metrics.

    Args:
        loader_valid: DataLoader yielding (image, mask) batches.
        model: density-map model already moved to CUDA by the caller.
        writer: tensorboard SummaryWriter; currently unused here (all
            writer calls were disabled), kept for interface stability.
        losses_tr, optimizer, epoch: passed straight through into the return
            tuple so the original 6-tuple shape is preserved.
            FIX: the original returned these names without defining them
            anywhere in the function, so every call ended in a NameError;
            they are now backward-compatible keyword parameters.

    Returns:
        (model, losses_tr, optimizer, epoch, metrics, loss_segm) where
        ``metrics`` is a dict with keys cae/mae/mse/qk/mcc/acc and
        ``loss_segm`` is the list of per-batch MSE losses.
    """
    print("\nValidation step...")
    segm_criterion = torch.nn.MSELoss()
    model.eval()
    with torch.no_grad():
        reg_met = dict(pred=[], trgt=[])
        loss_segm = []
        for j, (img, mask) in enumerate(loader_valid):
            # move data to device
            img = img.cuda().float()
            mask = mask.cuda().unsqueeze(1).float()
            # run inference
            out = model(img)
            # per-batch segmentation (density-map) loss
            segm_loss = segm_criterion(out, mask)
            # density maps are scaled by 100, so integrating the map and
            # dividing by 100 recovers the object count per image
            counts_pred = out.sum(dim=[1, 2, 3]).detach().cpu().numpy() / 100.
            counts_gt = mask.sum(dim=[1, 2, 3]).detach().cpu().numpy() / 100.
            reg_met["pred"].extend(counts_pred)
            reg_met["trgt"].extend(counts_gt)
            loss_segm.append(segm_loss.item())
            # log progress on one line
            print(f"\rValid epoch ({j+1}/{len(loader_valid)})", end="", flush=True)
    # regression metrics (count errors) and derived classification metrics
    cae, mae, mse = compute_reg_metrics(reg_met)
    qk, mcc, acc = compute_cls_metrics(reg_met)
    metrics = dict(cae=cae, mae=mae, mse=mse, qk=qk, mcc=mcc, acc=acc)
    return model, losses_tr, optimizer, epoch, metrics, loss_segm
# NOTE(review): orphaned fragment — every name used below (reg_met,
# counts_pred, counts_gt, segm_loss, conservation_loss, losses_val, loss,
# epoch, EPOCHS, start_ep, j, loader_valid, model, optimizer, losses_tr)
# is undefined at this scope. It looks like a stale duplicate of the
# validation-loop interior of train(); left byte-identical — confirm
# whether it can simply be deleted.
reg_met["pred"].extend(counts_pred)
reg_met["trgt"].extend(counts_gt)
# store losses
losses_val["segment"].append(segm_loss.item())
losses_val["conserv"].append(conservation_loss.item())
#writer.add_scalar("segmentation_loss/Valid", segm_loss.item(), epoch * len(loader_valid) + j)
#writer.add_scalar("regression_loss/Valid", conservation_loss.item(), epoch * len(loader_valid) + j)
# log
print(
    f"\rValid epoch {epoch + 1}/{EPOCHS + start_ep} ({j+1}/{len(loader_valid)}) loss:{loss.item():.4f}|segm_loss:{segm_loss.item():.2f} |cons_loss: {conservation_loss.item():.2f}",
    end="", flush=True)
# epoch-level metrics over the accumulated per-image counts
cae, mae, mse = compute_reg_metrics(reg_met)
qk, mcc, acc = compute_cls_metrics(reg_met)
#writer.add_scalar("metrics/cae", cae, epoch)
#writer.add_scalar("metrics/mae", mae, epoch)
#writer.add_scalar("metrics/mse", mse, epoch)
#writer.add_scalar("metrics/qkappa", qk, epoch)
#writer.add_scalar("metrics/mcc", mcc, epoch)
#writer.add_scalar("metrics/accuracy",acc,epoch)
# save checkpoint
last_checkpoint = {
    "model": model.state_dict(),
    "optimizer": optimizer.state_dict(),
    "losses_tr": losses_tr,
    "losses_val": losses_val,
    "epochs": epoch + 1
}
def train_func(loader_train, loader_valid, args, writer=None, checkpoint=None):
    """Train a density-map U-Net from a configuration dict and validate each epoch.

    Args:
        loader_train: DataLoader yielding (image, mask) training batches.
        loader_valid: DataLoader yielding (image, mask) validation batches.
        args: dict-like config read via ``.get``; keys used here:
            encoder_architecture, weights, freeze_encoder, diagnostic_run,
            learning_rate, epochs, optimizer, lr_coef_encoder,
            lr_scheduler_factor, resume.
        writer: optional tensorboard SummaryWriter.  FIX: the validation-phase
            ``writer.add_scalar`` calls were unguarded even though this
            defaults to None (the training loop already guarded with
            ``if writer:``); all logging is now skipped when writer is None.
        checkpoint: optional dict with a "model" state_dict to restore.

    Returns:
        (model, losses_tr, optimizer, epoch, metrics, loss_segm)
    """
    ENCODER_ARCH = args.get('encoder_architecture')
    PRETRAIN = args.get('weights')
    FREEZE_ENCODER = args.get('freeze_encoder')
    TRIAL_RUN = args.get('diagnostic_run')
    LR = args.get('learning_rate')
    EPOCHS = args.get('epochs')
    OPTIMIZER = args.get('optimizer')
    DF_ENC = args.get('lr_coef_encoder')
    LR_FACTOR = args.get('lr_scheduler_factor')
    RESUME = args.get('resume')
    print(f"check boolean: {FREEZE_ENCODER}")

    # ----------------------------------
    # C - MODEL, CHECKPOINTS AND RELATED
    # ----------------------------------
    if (not RESUME) and (PRETRAIN == "imagenet"):
        model = smp.Unet(ENCODER_ARCH, encoder_weights=PRETRAIN,
                         decoder_attention_type="scse")
        # FIX: typo "iamgenet" -> "imagenet" in the log message
        print("starting training with imagenet weights")
    else:
        model = smp.Unet(ENCODER_ARCH, decoder_attention_type="scse")
        if PRETRAIN == "lysto":
            # NOTE(review): lysto_checkpt_path is not defined in this function;
            # presumably a module-level global — verify before using
            # PRETRAIN == "lysto" on this code path.
            load_lysto_weights(model, lysto_checkpt_path, "resnet50")

    segm_criterion = torch.nn.MSELoss()

    start_ep = 0
    if checkpoint is not None:
        model.load_state_dict(checkpoint["model"])
        print("loaded checkpoint weights")

    model = nn.DataParallel(model)
    model.cuda(0)

    if FREEZE_ENCODER:
        for param in model.module.encoder.parameters():
            param.requires_grad = False

    # per-module learning rates: the encoder lr is scaled by DF_ENC so a
    # pretrained encoder can be fine-tuned more gently than the decoder
    param_groups = [
        {'params': model.module.encoder.parameters(), 'lr': LR * DF_ENC},
        {'params': model.module.decoder.parameters(), 'lr': LR},
        {'params': model.module.segmentation_head.parameters(), 'lr': LR}]

    # choosing the optimizer
    if OPTIMIZER == "adam":
        optimizer = torch.optim.Adam(param_groups)
    elif OPTIMIZER == "ranger":
        optimizer = Ranger(param_groups)
    else:
        raise ValueError("Specified optimizer not implemented")
    # if checkpoint is not None:
    #     optimizer.load_state_dict(checkpoint["optimizer"])

    # lr scheduler only when a reduction factor < 1 is requested
    if LR_FACTOR < 1.0:
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=LR_FACTOR)

    losses_tr = []
    losses_val = []
    for epoch in range(start_ep, EPOCHS + start_ep):
        # diagnostic runs only execute a handful of epochs/batches
        if TRIAL_RUN and (start_ep + epoch) > 3:
            break
        model.train()
        print(f"Start training epochs {epoch + 1}..")
        for i, (img, mask) in enumerate(loader_train):
            if TRIAL_RUN and i > 3:
                break
            # move data to device
            img = img.cuda().float()
            mask = mask.cuda().unsqueeze(1).float()
            # run inference
            out = model(img)
            # compute losses
            loss = segm_criterion(out, mask)
            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # log
            print(f"\rEpoch {epoch + 1}/{EPOCHS+start_ep} ({i+1}/{len(loader_train)})",
                  end="", flush=True)
            # store losses
            losses_tr.append(loss.item())
            if writer:
                writer.add_scalar("segmentation_loss/Train", loss.item(),
                                  epoch * len(loader_train) + i)
        print("")

        # ---- validation step (original wrapped this in a no-op `if True:`
        # and rebuilt an identical MSELoss criterion; both removed) ----
        print("\nValidation step...")
        model.eval()
        with torch.no_grad():
            reg_met = dict(pred=[], trgt=[])
            loss_segm = []
            for j, (img, mask) in enumerate(loader_valid):
                # move data to device
                img = img.cuda().float()
                mask = mask.cuda().unsqueeze(1).float()
                # run inference
                out = model(img)
                # compute losses
                segm_loss = segm_criterion(out, mask)
                # density maps are scaled by 100 -> integrate to get counts
                counts_pred = out.sum(dim=[1, 2, 3]).detach().cpu().numpy() / 100.
                counts_gt = mask.sum(dim=[1, 2, 3]).detach().cpu().numpy() / 100.
                reg_met["pred"].extend(counts_pred)
                reg_met["trgt"].extend(counts_gt)
                loss_segm.append(segm_loss.item())
                # FIX: guard the writer call — writer defaults to None
                if writer:
                    writer.add_scalar("segmentation_loss/Valid", segm_loss.item(),
                                      epoch * len(loader_valid) + j)
                # log
                print(f"\rValid epoch ({j+1}/{len(loader_valid)})",
                      end="", flush=True)
            losses_val.append(np.mean(loss_segm))
            cae, mae, mse = compute_reg_metrics(reg_met)
            qk, mcc, acc = compute_cls_metrics(reg_met)
            if writer:
                writer.add_scalar("metrics/cae", cae, epoch)
                writer.add_scalar("metrics/mae", mae, epoch)
                writer.add_scalar("metrics/mse", mse, epoch)
                writer.add_scalar("metrics/qkappa", qk, epoch)
                writer.add_scalar("metrics/mcc", mcc, epoch)
                writer.add_scalar("metrics/accuracy", acc, epoch)
            metrics = dict(cae=cae, mae=mae, mse=mse, qk=qk, mcc=mcc, acc=acc)

        # step the plateau scheduler on the epoch's mean validation loss
        if LR_FACTOR < 1.0:
            lr_scheduler.step(losses_val[-1])

    return model, losses_tr, optimizer, epoch, metrics, loss_segm
def train(gpu_proc_id, args, references):
    """Per-GPU DistributedDataParallel training entry point.

    Spawned once per GPU; joins the NCCL process group, builds the datasets,
    trains a density-map U-Net with MSE (segmentation) and L1 count-conservation
    losses, validates each epoch, and (rank 0 only) logs to tensorboard and
    saves "best"/"last" checkpoints under EXP_DIR.

    Args:
        gpu_proc_id: process rank == local GPU index.
        args: argparse namespace (encoder_architecture, weights, freeze_encoder,
            diagnostic_run, learning_rate, epochs, batch_size, label_sigma,
            decoder_scse, optimizer, gpu_id, resume).
        references: dict with DATASET_DIR, lysto_checkpt_path, EXP_DIR,
            experiment_title.
    """
    # train loop variables
    DATASET_DIR = references["DATASET_DIR"]
    lysto_checkpt_path = references["lysto_checkpt_path"]
    EXP_DIR = references["EXP_DIR"]
    exp_title = references["experiment_title"]
    ENCODER_ARCH = args.encoder_architecture
    PRETRAIN = args.weights
    FREEZE_ENCODER = args.freeze_encoder
    TRIAL_RUN = args.diagnostic_run
    LR = args.learning_rate
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LBL_SIGMA = args.label_sigma
    SCSE = args.decoder_scse
    # NOTE(review): OPTIMIZER is read but never used — Adam is hardcoded below
    # (train_func honours it); confirm whether "ranger" should be supported here.
    OPTIMIZER = args.optimizer
    world_size = len(args.gpu_id.split(","))

    # start process and start coordination with nccl backend
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=world_size,
        rank=gpu_proc_id
    )
    print(f"Process on gpu {gpu_proc_id} has started")

    # set seed and log path
    torch.manual_seed(0)
    torch.cuda.set_device(gpu_proc_id)
    writer = SummaryWriter(log_dir=f"./tb_runs/{exp_title}")

    # define albumentations transforms and datasets/datasamplers
    # execution is first to last
    transforms = A.Compose([
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05,
                      always_apply=False, p=0.99),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        ToTensorV2(),
    ])
    transforms_valid = A.Compose([
        ToTensorV2(),
    ])
    dataset_train = DMapData(DATASET_DIR, "train", transforms,
                             lbl_sigma_gauss=LBL_SIGMA)
    # FIX: the validation set was built with the augmenting `transforms`,
    # leaving `transforms_valid` defined but unused — validate un-augmented.
    dataset_valid = DMapData(DATASET_DIR, "valid", transforms_valid,
                             lbl_sigma_gauss=LBL_SIGMA)
    # dataset_train.show_example(False)

    # distributed samplers shard the data across ranks (hence shuffle=False)
    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset_train, num_replicas=world_size, rank=gpu_proc_id
    )
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset_valid, num_replicas=world_size, rank=gpu_proc_id
    )
    loader_train = DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, pin_memory=True, sampler=train_sampler
    )
    loader_valid = DataLoader(
        dataset_valid, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, pin_memory=True, sampler=valid_sampler
    )

    # checkpoints and weights
    # FIX: honour args.decoder_scse instead of hardcoding the attention type
    # on both construction paths (SCSE was read but never used)
    attention = "scse" if SCSE else None
    if (not args.resume) and (PRETRAIN == "imagenet"):
        model = smp.Unet(ENCODER_ARCH, encoder_weights=PRETRAIN,
                         decoder_attention_type=attention)
        # FIX: typo "iamgenet" -> "imagenet" in the log message
        print("starting training with imagenet weights")
    else:
        model = smp.Unet(ENCODER_ARCH, decoder_attention_type=attention)
        if PRETRAIN == "lysto":
            load_lysto_weights(model, lysto_checkpt_path, "resnet50")

    start_ep = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        print(f"loaded checkpoint {args.resume}")
        start_ep = checkpoint["epochs"]

    # instance model, optimizer etc
    model.cuda(gpu_proc_id)
    model = nn.parallel.DistributedDataParallel(
        model, device_ids=[gpu_proc_id], find_unused_parameters=True)
    if FREEZE_ENCODER:
        for param in model.module.encoder.parameters():
            param.requires_grad = False
    optimizer = torch.optim.Adam([
        {'params': model.module.encoder.parameters(), 'lr': LR},
        {'params': model.module.decoder.parameters(), 'lr': LR},
        {'params': model.module.segmentation_head.parameters(), 'lr': LR}
    ])
    if args.resume:
        optimizer.load_state_dict(checkpoint["optimizer"])

    segm_criterion = torch.nn.MSELoss()
    prob_cons_criterion = torch.nn.L1Loss()

    print("Starting training..")
    losses_tr = dict(segment=[], conserv=[])
    losses_val = dict(segment=[], conserv=[])
    best_cons_loss = np.inf
    for epoch in range(start_ep, EPOCHS + start_ep):
        # diagnostic runs only execute a handful of epochs/batches
        if TRIAL_RUN and epoch > 3:
            break
        model.train()
        print("Training step..")
        for i, (img, mask) in enumerate(loader_train):
            if TRIAL_RUN and i > 3:
                break
            # move data to device
            img = img.cuda(non_blocking=True).float()
            mask = mask.cuda(non_blocking=True).unsqueeze(1).float()
            # run inference
            out = model(img)
            # compute losses
            segm_loss = segm_criterion(out, mask)
            conservation_loss = prob_cons_criterion(out.sum(dim=[1, 2, 3]),
                                                    mask.sum(dim=[1, 2, 3]))
            # losses aggregation (conservation term currently disabled)
            loss = segm_loss  # + conservation_loss
            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # log
            print(
                f"\rEpoch {epoch + 1}/{EPOCHS+start_ep} ({i+1}/{len(loader_train)}) loss:{loss.item():.4f}|segm_loss:{segm_loss.item():.2f} |cons_loss: {conservation_loss.item():.2f}",
                end="", flush=True
            )
            # store losses
            losses_tr["segment"].append(segm_loss.item())
            losses_tr["conserv"].append(conservation_loss.item())
            if gpu_proc_id == 0:
                writer.add_scalar("segmentation_loss/Train", segm_loss.item(),
                                  epoch * len(loader_train) + i)
                writer.add_scalar("regression_loss/Train", conservation_loss.item(),
                                  epoch * len(loader_train) + i)

        print("\nValidation step...")
        model.eval()
        with torch.no_grad():
            reg_met = dict(pred=[], trgt=[])
            for j, (img, mask) in enumerate(loader_valid):
                if TRIAL_RUN and j > 3:
                    break
                # move data to device
                img = img.cuda(non_blocking=True).float()
                mask = mask.cuda(non_blocking=True).unsqueeze(1).float()
                # run inference
                out = model(img)
                # compute losses
                segm_loss = segm_criterion(out, mask)
                conservation_loss = prob_cons_criterion(out.sum(dim=[1, 2, 3]),
                                                        mask.sum(dim=[1, 2, 3]))
                # density maps are scaled by 100 -> integrate to get counts
                counts_pred = out.sum(dim=[1, 2, 3]).detach().cpu().numpy() / 100.
                counts_gt = mask.sum(dim=[1, 2, 3]).detach().cpu().numpy() / 100.
                reg_met["pred"].extend(counts_pred)
                reg_met["trgt"].extend(counts_gt)
                # store losses
                losses_val["segment"].append(segm_loss.item())
                losses_val["conserv"].append(conservation_loss.item())
                if gpu_proc_id == 0:
                    writer.add_scalar("segmentation_loss/Valid", segm_loss.item(),
                                      epoch * len(loader_valid) + j)
                    writer.add_scalar("regression_loss/Valid", conservation_loss.item(),
                                      epoch * len(loader_valid) + j)
                # log — FIX: the original printed the stale training-batch
                # `loss.item()` here; report the current validation losses.
                print(
                    f"\rValid epoch {epoch + 1}/{EPOCHS + start_ep} ({j+1}/{len(loader_valid)}) segm_loss:{segm_loss.item():.2f} |cons_loss: {conservation_loss.item():.2f}",
                    end="", flush=True
                )

        if gpu_proc_id == 0:
            # epoch-level count metrics, logged by rank 0 only
            cae, mae, mse = compute_reg_metrics(reg_met)
            qk, mcc, acc = compute_cls_metrics(reg_met)
            writer.add_scalar("metrics/cae", cae, epoch)
            writer.add_scalar("metrics/mae", mae, epoch)
            writer.add_scalar("metrics/mse", mse, epoch)
            writer.add_scalar("metrics/qkappa", qk, epoch)
            writer.add_scalar("metrics/mcc", mcc, epoch)
            writer.add_scalar("metrics/accuracy", acc, epoch)

        if gpu_proc_id == 0:
            # save checkpoint; "best" tracks the lowest mean conservation
            # loss over this epoch's validation batches, otherwise "last"
            last_checkpoint = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "losses_tr": losses_tr,
                "losses_val": losses_val,
                "epochs": epoch + 1
            }
            avg_val_loss = np.mean(losses_val["conserv"][-len(loader_valid):])
            if avg_val_loss < best_cons_loss:
                best_cons_loss = avg_val_loss
                name = "best"
            else:
                name = "last"
            torch.save(last_checkpoint, EXP_DIR + name + ".pth")
# NOTE(review): tail fragment of an evaluation script — `matches`, `pred`,
# `recall`, `metrics`, `ti`, `preds_num`, `gt_num`, `dataset_valid` and
# `checkpoint_path` are defined before this view begins. Left byte-identical.
# precision of detection matches; 0 when there are no predictions
precision = len(matches) / len(pred) if len(pred) else 0
# harmonic mean of precision and recall (epsilon avoids division by zero)
f_one = 2 / (1/(recall + 1e-6) + 1/(precision+1e-6))
# else: precision, recall, f_one = 0,0,0
metrics["recall"].append(recall)
metrics["precision"].append(precision)
metrics["f1"].append(f_one)
# wall-clock time of the whole inference pass
tt = time.time() -ti
# %%
preds_num = np.array(preds_num)
gt_num = np.array(gt_num)
# count-regression and derived classification metrics over the full set
cae, mae, mse = compute_reg_metrics(dict(pred=preds_num, trgt=gt_num))
qk, mcc, acc = compute_cls_metrics(dict(pred=preds_num, trgt=gt_num))
metrics["cae"] = cae
metrics["mae"] = mae
metrics["mse"] = mse
metrics["qk"] = qk
metrics["mcc"] = mcc
metrics["acc"] = acc
# %%
print(f"Completed inference on {len(dataset_valid)} tile in {tt:.2f}: {len(dataset_valid)/tt:.2f} fps")
print(Path(checkpoint_path).stem)
print("Average metrics over validation set")
# scalar entries are averaged trivially; list entries (recall/precision/f1)
# are averaged over all evaluated tiles
for k,v in metrics.items():
    print(f"{k:10s}: {np.mean(v):.2f}")
# %%