Example #1
def validate(loader_valid, model, writer):
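    """
    Run one validation pass: compute the per-batch MSE segmentation loss and collect
    predicted/target counts (summed density maps) for regression/classification metrics.
    """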
    print("\nValidation step...")
    segm_criterion = torch.nn.MSELoss()
    model.eval()
    with torch.no_grad():
        reg_met = dict(pred=[],trgt=[])
        # det_metrics = dict(precision=[], recall=[], f1=[])
        loss_segm = []
        for j, (img, mask) in enumerate(loader_valid):
            # move data to device
            img = img.cuda().float()
            mask = mask.cuda().unsqueeze(1).float()

            # run inference
            out = model(img)
            # compute losses
            segm_loss = segm_criterion(out, mask)

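            # convert density-map sums to object counts (the labels appear to be scaled
            # by 100, hence the division)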
            counts_pred = out.sum(dim=[1,2,3]).detach().cpu().numpy() / 100.
            counts_gt = mask.sum(dim=[1,2,3]).detach().cpu().numpy() / 100.
            reg_met["pred"].extend(counts_pred)
            reg_met["trgt"].extend(counts_gt)
            loss_segm.append(segm_loss.item())
    
            
            # writer.add_scalar("segmentation_loss/Valid", segm_loss.item(), epoch * len(loader_valid) + j)
            # writer.add_scalar("regression_loss/Valid", conservation_loss.item(), epoch * len(loader_valid) + j)


            # log 
            print(f"\rValid epoch ({j+1}/{len(loader_valid)})",end="", flush=True)
        #losses_val.append(np.mean(loss_segm))
        cae, mae, mse = compute_reg_metrics(reg_met)
        qk, mcc, acc = compute_cls_metrics(reg_met)
        # writer.add_scalar("metrics/cae", cae, epoch)
        # writer.add_scalar("metrics/mae", mae, epoch)
        # writer.add_scalar("metrics/mse", mse, epoch)
        # writer.add_scalar("metrics/qkappa", qk, epoch)
        # writer.add_scalar("metrics/mcc", mcc, epoch)
        # writer.add_scalar("metrics/accuracy",acc,epoch)
        metrics = dict(
            cae=cae,
            mae=mae,
            mse=mse,
            qk=qk,
            mcc=mcc,
            acc=acc
        )
    return model, metrics, loss_segm
Example #2
            counts_gt = mask.sum(dim=[1, 2, 3]).detach().cpu().numpy() / 100.
            reg_met["pred"].extend(counts_pred)
            reg_met["trgt"].extend(counts_gt)

            # store losses
            losses_val["segment"].append(segm_loss.item())
            losses_val["conserv"].append(conservation_loss.item())
            #writer.add_scalar("segmentation_loss/Valid", segm_loss.item(), epoch * len(loader_valid) + j)
            #writer.add_scalar("regression_loss/Valid", conservation_loss.item(), epoch * len(loader_valid) + j)

            # log
            print(
                f"\rValid epoch {epoch + 1}/{EPOCHS + start_ep} ({j+1}/{len(loader_valid)}) segm_loss:{segm_loss.item():.2f} |cons_loss: {conservation_loss.item():.2f}",
                end="",
                flush=True)
        cae, mae, mse = compute_reg_metrics(reg_met)
        qk, mcc, acc = compute_cls_metrics(reg_met)
        #writer.add_scalar("metrics/cae", cae, epoch)
        #writer.add_scalar("metrics/mae", mae, epoch)
        #writer.add_scalar("metrics/mse", mse, epoch)
        #writer.add_scalar("metrics/qkappa", qk, epoch)
        #writer.add_scalar("metrics/mcc", mcc, epoch)
        #writer.add_scalar("metrics/accuracy",acc,epoch)

    # save checkpoint
    last_checkpoint = {
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "losses_tr": losses_tr,
        "losses_val": losses_val,
        "epochs": epoch + 1
Example #3
def train_func(loader_train, loader_valid, args, writer=None, checkpoint=None):
    """
    Start training from a configuration file and a given training set
    """

    ENCODER_ARCH = args.get('encoder_architecture')
    PRETRAIN = args.get('weights')
    FREEZE_ENCODER = args.get('freeze_encoder')
    TRIAL_RUN = args.get('diagnostic_run')
    LR = args.get('learning_rate')
    EPOCHS = args.get('epochs')
    OPTIMIZER = args.get('optimizer')
    DF_ENC = args.get('lr_coef_encoder')
    LR_FACTOR = args.get('lr_scheduler_factor')
    RESUME = args.get('resume')
    print(f"check boolean: {FREEZE_ENCODER}")

    # ----------------------------------
    # C - MODEL, CHECKPOINTS AND RELATED
    # ----------------------------------
    if (not RESUME) and (PRETRAIN == "imagenet"):
        model = smp.Unet(ENCODER_ARCH, encoder_weights=PRETRAIN, decoder_attention_type="scse")
        print("starting training with imagenet weights")
    else:
        # skip ImageNet weights when resuming or when other pretraining is requested
        model = smp.Unet(ENCODER_ARCH, encoder_weights=None, decoder_attention_type="scse")

    if PRETRAIN == "lysto":
        load_lysto_weights(model, lysto_checkpt_path, "resnet50")

    segm_criterion = torch.nn.MSELoss()


    # %%
    start_ep = 0
    if checkpoint is not None:
        model.load_state_dict(checkpoint["model"])
        print(f"loaded checkpoint weights")

    model = nn.DataParallel(model)
    model.cuda(0)

    if FREEZE_ENCODER:
        for param in model.module.encoder.parameters():
            param.requires_grad = False

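    # differential learning rates: the encoder trains at LR * DF_ENC, while the decoder
    # and segmentation head use the base LR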
    param_groups = [
                    {'params': model.module.encoder.parameters(), 'lr':LR*DF_ENC},
                    {'params': model.module.decoder.parameters(), 'lr':LR},
                    {'params': model.module.segmentation_head.parameters(), 'lr':LR}]

    

    # choosing the optimizer
    if OPTIMIZER == "adam":
        optimizer = torch.optim.Adam(param_groups)
    elif OPTIMIZER == "ranger":
        optimizer = Ranger(param_groups)
    else:
        raise ValueError("Specified optimizer not implemented")

    #if checkpoint is not None:
        #optimizer.load_state_dict(checkpoint["optimizer"])

    # setting up lr scheduler
    if LR_FACTOR < 1.0:
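        # stepped on the epoch's mean validation segmentation loss (see the end of the epoch loop)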
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=LR_FACTOR)
    losses_tr = []
    losses_val = []
    for epoch in range(start_ep, EPOCHS + start_ep):
        if TRIAL_RUN and (epoch - start_ep) > 3:
            break

        model.train()
        print(f"Start training epochs {epoch + 1}..")
        for i, (img, mask) in enumerate(loader_train):
            if TRIAL_RUN and i>3:
                break

            # move data to device
            img = img.cuda().float()
            mask = mask.cuda().unsqueeze(1).float()

            # run inference
            out = model(img)
            # compute losses
            loss = segm_criterion(out, mask)

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # log 
            print(f"\rEpoch {epoch + 1}/{EPOCHS+start_ep} ({i+1}/{len(loader_train)})", end="", flush=True)
            # store losses
            losses_tr.append(loss.item())
            if writer:
                writer.add_scalar("segmentation_loss/Train", loss.item(), epoch * len(loader_train) + i)
        print("")


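        # validation runs after every training epoch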
        if True:
            print("\nValidation step...")
            segm_criterion = torch.nn.MSELoss()
            model.eval()
            with torch.no_grad():
                reg_met = dict(pred=[],trgt=[])
                # det_metrics = dict(precision=[], recall=[], f1=[])
                loss_segm = []
                for j, (img, mask) in enumerate(loader_valid):
                    # move data to device
                    img = img.cuda().float()
                    mask = mask.cuda().unsqueeze(1).float()

                    # run inference
                    out = model(img)
                    # compute losses
                    segm_loss = segm_criterion(out, mask)

                    counts_pred = out.sum(dim=[1,2,3]).detach().cpu().numpy() / 100.
                    counts_gt = mask.sum(dim=[1,2,3]).detach().cpu().numpy() / 100.
                    reg_met["pred"].extend(counts_pred)
                    reg_met["trgt"].extend(counts_gt)
                    loss_segm.append(segm_loss.item())
            
                    
                    writer.add_scalar("segmentation_loss/Valid", segm_loss.item(), epoch * len(loader_valid) + j)
                    # writer.add_scalar("regression_loss/Valid", conservation_loss.item(), epoch * len(loader_valid) + j)


                    # log 
                    print(f"\rValid epoch ({j+1}/{len(loader_valid)})",end="", flush=True)
                losses_val.append(np.mean(loss_segm))
                cae, mae, mse = compute_reg_metrics(reg_met)
                qk, mcc, acc = compute_cls_metrics(reg_met)
                writer.add_scalar("metrics/cae", cae, epoch)
                writer.add_scalar("metrics/mae", mae, epoch)
                writer.add_scalar("metrics/mse", mse, epoch)
                writer.add_scalar("metrics/qkappa", qk, epoch)
                writer.add_scalar("metrics/mcc", mcc, epoch)
                writer.add_scalar("metrics/accuracy",acc,epoch)
                metrics = dict(
                    cae=cae,
                    mae=mae,
                    mse=mse,
                    qk=qk,
                    mcc=mcc,
                    acc=acc
                )

        if LR_FACTOR < 1.0:
            lr_scheduler.step(losses_val[-1])

    return model, losses_tr, optimizer, epoch, metrics, loss_segm

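# A hypothetical invocation of train_func, for illustration only: the dict keys mirror the
# `args.get(...)` lookups above, and the concrete values are placeholders; `loader_train`,
# `loader_valid` and `writer` are assumed to exist already.
#
#     args = dict(
#         encoder_architecture="resnet50",
#         weights="imagenet",
#         freeze_encoder=False,
#         diagnostic_run=False,
#         learning_rate=1e-4,
#         epochs=50,
#         optimizer="adam",
#         lr_coef_encoder=0.1,
#         lr_scheduler_factor=0.5,
#         resume=False,
#     )
#     model, losses_tr, optimizer, last_epoch, metrics, loss_segm = train_func(
#         loader_train, loader_valid, args, writer=writer
#     )
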
def train(gpu_proc_id, args, references):
    
    # train loop variables

    DATASET_DIR = references["DATASET_DIR"]
    lysto_checkpt_path = references["lysto_checkpt_path"]
    EXP_DIR = references["EXP_DIR"]
    exp_title = references["experiment_title"]

    ENCODER_ARCH = args.encoder_architecture
    PRETRAIN = args.weights
    FREEZE_ENCODER = args.freeze_encoder
    TRIAL_RUN = args.diagnostic_run
    LR = args.learning_rate
    EPOCHS = args.epochs
    BATCH_SIZE = args.batch_size
    LBL_SIGMA = args.label_sigma
    SCSE = args.decoder_scse
    OPTIMIZER = args.optimizer


    # start the process and join the process group using the NCCL backend
    dist.init_process_group(
        backend='nccl',
        init_method='env://',
        world_size=len(args.gpu_id.split(",")),
        rank=gpu_proc_id
    )
    print(f"Process on gpu {gpu_proc_id} has started")

    # set seed and log path

    torch.manual_seed(0)
    torch.cuda.set_device(gpu_proc_id)
    writer = SummaryWriter(log_dir=f"./tb_runs/{exp_title}")

    # %%
    # define albumentations transforms and datasets/datasamplers
    # execution is first to last

    transforms = A.Compose([
        A.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.05, always_apply=False, p=0.99),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        ToTensorV2(),
    ])
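    # validation transforms: tensor conversion only, no augmentation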
    transforms_valid = A.Compose([
        ToTensorV2(),
    ])

    dataset_train = DMapData(DATASET_DIR, "train", transforms, lbl_sigma_gauss=LBL_SIGMA)
    dataset_valid = DMapData(DATASET_DIR, "valid", transforms_valid, lbl_sigma_gauss=LBL_SIGMA)
    # dataset_train.show_example(False)

    train_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset_train,
        num_replicas=len(args.gpu_id.split(",")),
        rank=gpu_proc_id
    )
    valid_sampler = torch.utils.data.distributed.DistributedSampler(
        dataset_valid,
        num_replicas=len(args.gpu_id.split(",")),
        rank=gpu_proc_id
    )

    loader_train = DataLoader(
        dataset_train, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, pin_memory=True, sampler=train_sampler
        )
    loader_valid = DataLoader(
        dataset_valid, batch_size=BATCH_SIZE, shuffle=False,
        num_workers=0, pin_memory=True, sampler=valid_sampler
        )

    # checkpoints and weights
    # %%
    if (not args.resume) and (PRETRAIN == "imagenet"):
        model = smp.Unet(ENCODER_ARCH, encoder_weights=PRETRAIN, decoder_attention_type="scse")
        print("starting training with imagenet weights")
    else:
        # skip ImageNet weights when resuming or when other pretraining is requested
        model = smp.Unet(ENCODER_ARCH, encoder_weights=None, decoder_attention_type="scse")

    if PRETRAIN == "lysto":
        load_lysto_weights(model, lysto_checkpt_path, "resnet50")
    

    start_ep = 0
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint["model"])
        print(f"loaded checkpoint {args.resume}")
        start_ep = checkpoint["epochs"]

    # %%
    # instance model, optimizer etc
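    # one process per GPU: move the model to its device, then wrap it with DistributedDataParallel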
    model.cuda(gpu_proc_id)
    model = nn.parallel.DistributedDataParallel(model, device_ids=[gpu_proc_id], find_unused_parameters=True)

    if FREEZE_ENCODER:
        for param in model.module.encoder.parameters():
            param.requires_grad = False

    optimizer = torch.optim.Adam([
                    {'params': model.module.encoder.parameters(), 'lr':LR},
                    {'params': model.module.decoder.parameters(), 'lr':LR},
                    {'params': model.module.segmentation_head.parameters(), 'lr':LR}
    ])

    if args.resume:
        optimizer.load_state_dict(checkpoint["optimizer"])

    segm_criterion = torch.nn.MSELoss()
    prob_cons_criterion = torch.nn.L1Loss()


    print("Starting training..")
    # %%
    losses_tr = dict(segment=[], conserv=[])
    losses_val = dict(segment=[], conserv=[])
    best_cons_loss = np.inf
    for epoch in range(start_ep, EPOCHS + start_ep):
        if TRIAL_RUN and (epoch - start_ep) > 3:
            break

        model.train()
        print("Training step..")
        for i, (img, mask) in enumerate(loader_train):
            if TRIAL_RUN and i>3:
                break
            # move data to device
            img = img.cuda(non_blocking=True).float()
            mask = mask.cuda(non_blocking=True).unsqueeze(1).float()

            # run inference
            out = model(img)
            # compute losses
            segm_loss = segm_criterion(out, mask)
            conservation_loss = prob_cons_criterion(out.sum(dim=[1,2,3]), mask.sum(dim=[1,2,3]))
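            # conservation_loss penalises the L1 gap between predicted and target total density
            # mass (a proxy for the object count); it is logged but, as the aggregation below
            # shows, not currently part of the training objective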
            # losses aggregation
            loss = segm_loss #+ conservation_loss

            # backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # log 
            print(
                f"\rEpoch {epoch + 1}/{EPOCHS+start_ep} ({i+1}/{len(loader_train)}) loss:{loss.item():.4f}|segm_loss:{segm_loss.item():.2f} |cons_loss: {conservation_loss.item():.2f}",
                end="", flush=True
                )
            # store losses
            losses_tr["segment"].append(segm_loss.item())
            losses_tr["conserv"].append(conservation_loss.item())
            if gpu_proc_id == 0:
                writer.add_scalar("segmentation_loss/Train", segm_loss.item(), epoch * len(loader_train) + i)
                writer.add_scalar("regression_loss/Train", conservation_loss.item(), epoch * len(loader_train) + i)

        print("\nValidation step...")
        model.eval()
        with torch.no_grad():
            reg_met = dict(pred=[],trgt=[])

            for j, (img, mask) in enumerate(loader_valid):
                if TRIAL_RUN and j>3:
                    break
                # move data to device
                img = img.cuda(non_blocking=True).float()
                mask = mask.cuda(non_blocking=True).unsqueeze(1).float()

                # run inference
                out = model(img)
                # compute losses
                segm_loss = segm_criterion(out, mask)
                conservation_loss = prob_cons_criterion(out.sum(dim=[1,2,3]), mask.sum(dim=[1,2,3]))


                counts_pred = out.sum(dim=[1,2,3]).detach().cpu().numpy() / 100.
                counts_gt = mask.sum(dim=[1,2,3]).detach().cpu().numpy() / 100.
                reg_met["pred"].extend(counts_pred)
                reg_met["trgt"].extend(counts_gt)

                # store losses
                losses_val["segment"].append(segm_loss.item())
                losses_val["conserv"].append(conservation_loss.item())
                if gpu_proc_id == 0:
                    writer.add_scalar("segmentation_loss/Valid", segm_loss.item(), epoch * len(loader_valid) + j)
                    writer.add_scalar("regression_loss/Valid", conservation_loss.item(), epoch * len(loader_valid) + j)


                # log 
                print(
                    f"\rValid epoch {epoch + 1}/{EPOCHS + start_ep} ({j+1}/{len(loader_valid)}) segm_loss:{segm_loss.item():.2f} |cons_loss: {conservation_loss.item():.2f}",
                    end="", flush=True
                    )
            if gpu_proc_id == 0:
                cae, mae, mse = compute_reg_metrics(reg_met)
                qk, mcc, acc = compute_cls_metrics(reg_met)
                writer.add_scalar("metrics/cae", cae, epoch)
                writer.add_scalar("metrics/mae", mae, epoch)
                writer.add_scalar("metrics/mse", mse, epoch)
                writer.add_scalar("metrics/qkappa", qk, epoch)
                writer.add_scalar("metrics/mcc", mcc, epoch)
                writer.add_scalar("metrics/accuracy",acc,epoch)

        if gpu_proc_id == 0:
            # save checkpoint
            last_checkpoint = {
                "model":model.state_dict(),
                "optimizer":optimizer.state_dict(),
                "losses_tr":losses_tr,
                "losses_val":losses_val,
                "epochs":epoch+1
            }

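            # keep "best" when the epoch's mean validation conservation loss improves,
            # otherwise overwrite "last"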
            avg_val_loss = np.mean(losses_val["conserv"][-len(loader_valid):])
            if avg_val_loss < best_cons_loss:
                best_cons_loss = avg_val_loss
                name = "best"
            else:
                name = "last"

            torch.save(last_checkpoint, EXP_DIR + name + ".pth")
Example #5
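            # per-tile detection metrics from the matched detections; the 1e-6 terms keep
            # the harmonic-mean F1 finite when recall or precision is zero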
            recall = len(matches) / len(gt) if len(gt) else 0
            precision = len(matches) / len(pred) if len(pred) else 0 
            f_one = 2 / (1/(recall + 1e-6) + 1/(precision+1e-6))
            #
        else:
            precision, recall, f_one = 0,0,0

        metrics["recall"].append(recall)
        metrics["precision"].append(precision)
        metrics["f1"].append(f_one)
tt = time.time() - ti
# %%

preds_num = np.array(preds_num)
gt_num = np.array(gt_num)
cae, mae, mse = compute_reg_metrics(dict(pred=preds_num, trgt=gt_num))
qk, mcc, acc = compute_cls_metrics(dict(pred=preds_num, trgt=gt_num))
metrics["cae"] = cae
metrics["mae"] = mae
metrics["mse"] = mse
metrics["qk"] = qk
metrics["mcc"] = mcc
metrics["acc"] = acc
# %%
print(f"Completed inference on {len(dataset_valid)} tile in {tt:.2f}: {len(dataset_valid)/tt:.2f} fps")
print(Path(checkpoint_path).stem)
print("Average metrics over validation set")
for k, v in metrics.items():
    print(f"{k:10s}: {np.mean(v):.2f}")