Пример #1
0
def predict(net, full_img, device, input_size, mask_way='warp'):
    '''
    Run ``net`` on a single image and return the mask together with theta.

    :param net: model exposing ``eval()``, ``predict(img, warp=...)`` and
        ``n_classes``.
    :param full_img: input image, forwarded to ``BasicDataset.preprocess_img``.
    :param device: torch device the preprocessed tensor is moved to.
    :param input_size: size argument forwarded to
        ``BasicDataset.preprocess_img``.
    :param mask_way: Sets the way to obtain the mask. Can take 'warp' or 'segm'.
    :return: ``(mask, theta)``. For 'warp' the mask is a ``uint8`` ndarray
        derived from the network's ``rec_mask``; for 'segm' it is whatever
        ``preds_to_masks`` returns for the logits.
    :raises NotImplementedError: if ``mask_way`` is neither 'warp' nor 'segm'.
    '''
    # Fail fast: reject an unknown mode before paying for preprocessing and
    # a forward pass.
    if mask_way not in ('warp', 'segm'):
        raise NotImplementedError(
            f"mask_way must be 'warp' or 'segm', got {mask_way!r}")
    use_warp = mask_way == 'warp'

    # Preprocess input image:
    img = BasicDataset.preprocess_img(full_img, input_size)
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    net.eval()

    # Predict:
    with torch.no_grad():
        logits, rec_mask, theta = net.predict(img, warp=use_warp)

    if use_warp:
        # Scale the reconstructed mask by the class count, then truncate to
        # integers before converting to a uint8 array.
        mask = rec_mask * net.n_classes
        mask = mask.type(torch.IntTensor).cpu().numpy().astype(np.uint8)
    else:
        mask = preds_to_masks(logits, net.n_classes)

    return mask, theta
Пример #2
0
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """
    Predict a thresholded mask for a single image.

    :param net: model called as ``net(batch)``; its output is softmaxed over
        dim 1.
    :param full_img: source image; it is passed to
        ``BasicDataset.preprocess`` and ``full_img.size[1]`` is used as the
        resize argument (presumably a PIL image -- confirm against callers).
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: values strictly above this become ``True``.
    :return: boolean ndarray ``full_mask > out_threshold``.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        class_probs = F.softmax(net(batch), dim=1).squeeze(0)

        # Round-trip through a PIL image so torchvision can resize the
        # prediction using the original image's size[1].
        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor(),
        ])
        resized = restore_size(class_probs.cpu())
        full_mask = resized.squeeze().cpu().numpy()

    return full_mask > out_threshold
Пример #3
0
def getDataSet(inputDir, labelDir, imgScale, valPercent):
    """
    Build a ``BasicDataset`` and randomly split it into train/validation parts.

    :param inputDir: image directory passed to ``BasicDataset``.
    :param labelDir: label directory passed to ``BasicDataset``.
    :param imgScale: scale passed to ``BasicDataset``.
    :param valPercent: fraction of the samples reserved for validation.
    :return: tuple ``(train, val)`` of the two random subsets.
    """
    full_set = BasicDataset(inputDir, labelDir, imgScale)
    total = len(full_set)

    # valPercent of the samples (rounded down) go to validation, the
    # remainder is used for training.
    val_count = int(total * valPercent)
    train_count = total - val_count

    # The assignment of samples to each part is random.
    train_part, val_part = random_split(full_set, [train_count, val_count])
    return train_part, val_part
Пример #4
0
def predict_img(net, full_img, device, input_size):
    """
    Run ``net`` on one image and convert its predictions to masks.

    :param net: model called as ``net(batch)`` and exposing ``n_classes``.
    :param full_img: input image, forwarded to ``BasicDataset.preprocess_img``.
    :param device: torch device the input batch is moved to.
    :param input_size: size argument forwarded to
        ``BasicDataset.preprocess_img``.
    :return: result of ``preds_to_masks(preds, net.n_classes)``.
    """
    net.eval()

    batch = BasicDataset.preprocess_img(full_img, input_size).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        predictions = net(batch)
        # Per the original author's note, preds_to_masks turns the GPU
        # tensor into a CPU numpy array.
        return preds_to_masks(predictions, net.n_classes)
Пример #5
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                use_dense_crf=False):
    """
    Predict a thresholded mask for one image, optionally refined by dense CRF.

    :param net: model called as ``net(batch)`` and exposing ``n_classes``.
    :param full_img: source image; ``full_img.shape[1]`` is used as the resize
        argument, so an array-like with ``shape`` is expected.
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale given to the ``BasicDataset`` used for
        preprocessing.
    :param out_threshold: values strictly above this become ``True``.
    :param use_dense_crf: when true, the mask is passed through ``dense_crf``
        together with the ``uint8`` image before thresholding.
    :return: boolean ndarray ``full_mask > out_threshold``.
    """
    net.eval()

    # A dataset instance is created only to reach its preprocess() method.
    preprocessor = BasicDataset('', '', scale=scale_factor)
    batch = torch.from_numpy(preprocessor.preprocess(full_img))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)

        # Softmax for multi-class heads, sigmoid for a single-class head.
        activated = (F.softmax(logits, dim=1)
                     if net.n_classes > 1 else torch.sigmoid(logits))
        activated = activated.squeeze(0)

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.shape[1]),
            transforms.ToTensor(),
        ])
        resized = restore_size(activated.cpu())
        full_mask = resized.squeeze().cpu().numpy()

    if use_dense_crf:
        image_u8 = np.array(full_img).astype(np.uint8)
        full_mask = dense_crf(image_u8, full_mask)

    return full_mask > out_threshold
def prediction(net, imgs, device):
    """
    Predict a binary mask for one image with a fixed 0.5 threshold.

    :param net: model called as ``net(batch)``; its output goes through a
        sigmoid.
    :param imgs: a single source image; ``imgs.size[1]`` is used as the resize
        argument (presumably a PIL image -- confirm against callers).
    :param device: torch device the input batch is moved to.
    :return: boolean ndarray, ``True`` where the resized sigmoid output
        exceeds 0.5.
    """
    net.eval()

    # The dataset is instantiated with the training paths and scale 0.5 only
    # to reuse its preprocess() method.
    preprocessor = BasicDataset('data/training/img',
                                'data/training/full_mask',
                                scale=0.5)
    batch = torch.from_numpy(preprocessor.preprocess(imgs))
    batch = torch.unsqueeze(batch, 0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        activated = torch.sigmoid(net(batch)).squeeze(0)

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(imgs.size[1]),
            transforms.ToTensor(),
        ])
        resized = restore_size(activated.cpu())
        full_mask = resized.squeeze().cpu().numpy()

    return full_mask > 0.5
Пример #7
0
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """
    Predict a per-pixel class-index map for a single image.

    :param net: model called as ``net(batch)``.
    :param full_img: input image, forwarded to ``BasicDataset.preprocess``.
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: accepted for interface compatibility; not used.
    :return: ndarray holding, for the first batch item, the index of the
        maximum along dim 1 of the network output.
    """
    net.eval()

    prepared = BasicDataset.preprocess(full_img, scale_factor, False)
    batch = torch.from_numpy(prepared).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        scores = net(batch)
        # Keep only the indices of the per-pixel maximum over dim 1.
        _, winners = scores.max(dim=1)
        label_map = winners.unsqueeze(1).data.cpu().numpy()
        return label_map[0, 0, :, :]
Пример #8
0
def plot_imgs_pred():
    """
    Plot the input image and the per-class predictions of a trained model.

    Configuration comes from ``get_plot_args()``: the TIFF image path
    (``dir_img``), preprocessing ``scale``, ``model_arch`` ('unet' or
    'icnet'), the checkpoint path (``checkpoint_net``) and the output prefix
    (``dir_output``). The first channel of the input is saved as
    ``original_img.png`` and one ``pred_cls_<k>.png`` is saved per class.

    :raises ValueError: if ``args.model_arch`` is neither 'unet' nor 'icnet'.
    :return: None
    """
    args = get_plot_args()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    img = tiff.imread(args.dir_img)
    img = BasicDataset.preprocess(img, scale=args.scale)
    img = torch.from_numpy(img).type(torch.FloatTensor)

    if args.model_arch == 'unet':
        net = UNet(n_channels=4, n_classes=4, bilinear=True)
    elif args.model_arch == 'icnet':
        net = ICNet(n_channels=4, n_classes=4, pretrained_base=False)
    else:
        # Without this branch ``net`` would be unbound below.
        raise ValueError(f"Unsupported model_arch: {args.model_arch!r}")

    net.load_state_dict(torch.load(args.checkpoint_net, map_location=device))
    net.to(device=device)
    net.eval()

    img = img.to(device=device, dtype=torch.float32)
    img = img.unsqueeze(0)

    with torch.no_grad():
        if args.model_arch == 'icnet':
            # ICNet returns four outputs; only the first one is plotted.
            mask_pred, _, _, _ = net(img)
        else:
            mask_pred = net(img)

    # matplotlib cannot read CUDA tensors, so move data to the CPU first.
    plt.imshow(img[0][0].cpu())
    plt.colorbar()
    plt.savefig(args.dir_output + "original_img.png")
    plt.clf()

    for class_scores in mask_pred:
        n_classes = class_scores.size(0)
        class_scores = torch.sigmoid(class_scores)
        max_index = torch.max(class_scores, 0).indices
        for class_index in range(n_classes):
            # Show the prediction: pixels where this class has the top score.
            jaccard_input = (max_index == class_index).float()
            plt.imshow(jaccard_input.cpu())
            plt.colorbar()
            plt.savefig(args.dir_output + f"pred_cls_{class_index}.png")
            plt.clf()
def show_gt():
    """
    Display every dataset image next to its ground-truth mask with OpenCV.

    Uses the module-level ``dir_img`` and ``dir_mask``, scale 1 and a batch
    size of 1; after showing each pair it calls ``cv2.waitKey()``.
    """
    loader = DataLoader(BasicDataset(dir_img, dir_mask, 1),
                        batch_size=1,
                        shuffle=True,
                        num_workers=8,
                        pin_memory=True)

    for sample in loader:
        # Take the single batch item and reorder its axes with
        # transpose(1, 2, 0) before display.
        img, mask = (sample[key].numpy()[0].transpose(1, 2, 0)
                     for key in ('image', 'mask'))
        cv2.imshow("img", img)
        cv2.imshow("mask", mask)
        cv2.waitKey()
Пример #10
0
def fun1():
    """
    Print the shape of the first training image batch, then exit.

    Builds the dataset from ``data/imgs/`` and ``data/masks/`` at scale 0.5,
    holds out 10% for validation and loads the training part with batch
    size 1. The process exits with status 0 after the first batch.
    """
    dataset = BasicDataset('data/imgs/', 'data/masks/', 0.5)
    held_out = int(len(dataset) * 0.1)
    kept = len(dataset) - held_out
    train_part, val_part = random_split(dataset, [kept, held_out])

    loader = DataLoader(train_part,
                        batch_size=1,
                        shuffle=True,
                        num_workers=8,
                        pin_memory=True)

    for sample in loader:
        imgs, true_masks = sample['image'], sample['mask']
        print(imgs.shape)
        exit(0)
Пример #11
0
def predict_img(net, full_img, device, input_size):
    """
    Run ``net`` on one image and return its mask and projection as arrays.

    :param net: model whose call returns ``(mask_pred, mask_proj)`` and which
        exposes ``n_classes``.
    :param full_img: input image, forwarded to ``BasicDataset.preprocess_img``.
    :param device: torch device the input batch is moved to.
    :param input_size: size argument forwarded to
        ``BasicDataset.preprocess_img``.
    :return: ``(mask, proj)`` where ``mask`` comes from ``preds_to_masks`` and
        ``proj`` is a ``uint8`` ndarray.
    """
    # Build a one-image batch on the target device.
    batch = BasicDataset.preprocess_img(full_img, input_size).unsqueeze(0)
    batch = batch.to(device=device, dtype=torch.float32)

    net.eval()

    with torch.no_grad():
        mask_pred, mask_proj = net(batch)

    # Convert both outputs to ndarrays.
    mask = preds_to_masks(mask_pred, net.n_classes)
    # The projection is scaled by the class count and truncated to integers.
    scaled_proj = (mask_proj * net.n_classes).type(torch.IntTensor)
    proj = scaled_proj.cpu().numpy().astype(np.uint8)

    return mask, proj
Пример #12
0
def validation_only(net,
                    device,
                    batch_size=1,
                    img_width=0,
                    img_height=0,
                    img_scale=1.0,
                    use_bw=False,
                    standardize=False,
                    compute_statistics=False):
    """
    Evaluate ``net`` on the test dataset and log the resulting score.

    The dataset is built from the module-level ``dir_img_test`` and
    ``dir_mask_test``. ``load_statistics`` is passed as the inverse of
    ``compute_statistics`` and ``save_statistics`` is always ``True``.

    :param net: model to evaluate; ``net.n_classes`` selects the log label.
    :param device: device forwarded to ``eval_net``.
    :param batch_size: batch size of the evaluation loader.
    :param img_width: forwarded to ``BasicDataset``.
    :param img_height: forwarded to ``BasicDataset``.
    :param img_scale: forwarded to ``BasicDataset``.
    :param use_bw: forwarded to ``BasicDataset``.
    :param standardize: forwarded to ``BasicDataset``.
    :param compute_statistics: when false, the dataset is asked to load
        statistics instead.
    :return: None
    """
    reuse_statistics = not compute_statistics
    test_set = BasicDataset(dir_img_test,
                            dir_mask_test,
                            img_width,
                            img_height,
                            img_scale,
                            use_bw,
                            standardize=standardize,
                            load_statistics=reuse_statistics,
                            save_statistics=True)
    loader = DataLoader(test_set,
                        batch_size=batch_size,
                        shuffle=False,
                        num_workers=8,
                        pin_memory=True,
                        drop_last=True)

    score = eval_net(net, loader, device)

    # The score is labelled as cross entropy for multi-class networks and as
    # a Dice coefficient otherwise.
    metric_name = 'cross entropy' if net.n_classes > 1 else 'Dice Coeff'
    logging.info(f'Validation {metric_name}: {score}')
Пример #13
0
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.5):
    """
    Predict a thresholded mask from the per-pixel winning class index.

    The class index map (argmax over channels) is multiplied by 10, converted
    to float32, round-tripped through a PIL image for resizing, and finally
    compared with ``out_threshold``.

    :param net: model called as ``net(batch)`` and exposing ``n_classes``.
    :param full_img: source image; ``full_img.size[1]`` is used as the resize
        argument (presumably a PIL image -- confirm against callers).
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: values strictly above this become ``True``.
    :return: boolean ndarray ``full_mask > out_threshold``.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)

        if net.n_classes > 1:
            scores = F.softmax(logits, dim=1)
        else:
            scores = torch.sigmoid(logits)

        # Index of the best-scoring channel for every pixel.
        class_idx = torch.max(scores.squeeze(0), dim=0)[1]
        print(f'probs.max():{class_idx.max()}')

        # Re-add a channel axis and spread the indices by a factor of 10.
        scaled = (class_idx.unsqueeze(0) * 10).type(torch.float32)

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor(),
        ])
        resized = restore_size(scaled.cpu())
        full_mask = resized.squeeze().cpu().numpy()

    return full_mask > out_threshold
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.0):
    """
    Predict a class-index map for one image, printing debug information.

    :param net: model called as ``net(batch)`` and exposing ``n_classes``.
    :param full_img: input image, forwarded to ``BasicDataset.preprocess``.
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: accepted for interface compatibility; not used.
    :return: ndarray with the argmax over axis 0 of the activated output.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        scores = (F.softmax(logits, dim=1)
                  if net.n_classes > 1 else torch.sigmoid(logits))
        scores = scores.squeeze(0)
        print(scores)

        score_map = scores.cpu().numpy()
        print(type(score_map))
        print("*******************************************************")

        print(score_map)
        # Collapse the leading axis to the index of its largest entry.
        label_map = score_map.argmax(axis=0)
        print("--***********************************************")
        print(label_map.shape)

    return label_map
Пример #15
0
def predict_img(net,
                full_img,
                device,
                scale_factor=0.5,
                out_threshold=0.5):
    """
    Predict two independent sigmoid maps (output channels 0 and 1).

    :param net: model called as ``net(batch)``; channels 0 and 1 of its
        output are used.
    :param full_img: source image; its ``height`` and ``width`` attributes
        give the resize target.
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: accepted for interface compatibility; not used, the
        maps are returned without thresholding.
    :return: ndarray stacking the two resized maps along a new first axis.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(batch)
        print(output.shape)

        # Each channel gets its own sigmoid rather than a joint softmax.
        channel_probs = [
            torch.sigmoid(output[:, channel, :, :]).squeeze(0)
            for channel in (0, 1)
        ]

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((full_img.height, full_img.width)),
            transforms.ToTensor(),
        ])

        channel_masks = [
            restore_size(prob.cpu()).squeeze().cpu().numpy()
            for prob in channel_probs
        ]
        full_mask = np.array(channel_masks)

    return full_mask
Пример #16
0
def predict_img(net,
                full_img1,
                full_img2,
                full_img3,
                full_img4,
                full_img5,
                device,
                scale_factor=0.267,
                out_threshold=0.6):
    """
    Run ``net`` on five jointly preprocessed inputs and return its raw output.

    :param net: model called as ``net(batch)``.
    :param full_img1: first input, forwarded to ``BasicDataset.preprocess``.
    :param full_img2: second input, forwarded likewise.
    :param full_img3: third input, forwarded likewise.
    :param full_img4: fourth input, forwarded likewise.
    :param full_img5: fifth input, forwarded likewise.
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: accepted for interface compatibility; not used.
    :return: ndarray with the squeezed network output.
    """
    net.eval()

    img = torch.from_numpy(
        BasicDataset.preprocess(full_img1, full_img2, full_img3, full_img4,
                                full_img5, scale_factor))
    img = img.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)
        # NOTE(review): the values returned are the network's raw output.
        # The previous revision computed softmax/sigmoid but then overwrote
        # the result with ``output``; that dead work (plus an unused
        # SimpleITK conversion, unused maxima and an unused transform) was
        # removed without changing what is returned. Confirm whether
        # activated probabilities were intended.
        full_mask = output.squeeze().cpu().numpy()

    return full_mask
Пример #17
0
def test_net(net, device, batch_size=4, scale=512, threshold=0.5):
    """
    Evaluate ``net`` on the dataset under ``dir_img``/``dir_mask`` and print
    the metrics returned by ``eval_net``.

    :param net: model to evaluate; ``net.n_classes`` selects what is printed.
    :param device: device forwarded to ``eval_net``.
    :param batch_size: batch size of the evaluation loader.
    :param scale: accepted but not used; the dataset is built with the
        literal 512.
    :param threshold: threshold forwarded to ``eval_net``.
    :return: None
    """
    loader = DataLoader(BasicDataset(dir_img, dir_mask, 512, False, 5),
                        batch_size=batch_size,
                        shuffle=True,
                        num_workers=8,
                        pin_memory=True)

    timer = TimeManager()

    val_score, precision, recall = eval_net(net, loader, device, threshold)

    # Precision and recall are only reported for single-class networks.
    if net.n_classes > 1:
        report = [('Validation cross entropy:', val_score)]
    else:
        report = [
            ('Validation Dice Coeff:', val_score),
            ('Validation Precision:', precision),
            ('Validation Recall:', recall),
        ]
    for label, value in report:
        print(label, value)

    timer.show()
Пример #18
0
def inference_one(net, image, device):
    """
    Predict thresholded mask(s) for one image using the settings in ``cfg``.

    :param net: model called as ``net(batch)``; with ``cfg.deepsupervision``
        the last element of its output is used.
    :param image: source image; ``image.size`` supplies the resize target as
        ``(size[1], size[0])``.
    :param device: torch device the input batch is moved to.
    :return: a single boolean ndarray when ``cfg.n_classes == 1``, otherwise
        a list with one boolean ndarray per output channel.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(image, cfg.scale))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        if cfg.deepsupervision:
            logits = logits[-1]

        activated = (F.softmax(logits, dim=1)
                     if cfg.n_classes > 1 else torch.sigmoid(logits))
        activated = activated.squeeze(0)  # C x H x W

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((image.size[1], image.size[0])),
            transforms.ToTensor(),
        ])

        if cfg.n_classes == 1:
            single = restore_size(activated.cpu()).squeeze().cpu().numpy()
            return single > cfg.out_threshold

        # Multi-class: resize and threshold every channel on its own.
        return [
            restore_size(channel.cpu()).squeeze().cpu().numpy()
            > cfg.out_threshold
            for channel in activated
        ]
Пример #19
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """
    Predict thresholded centerline and point masks for one image.

    :param net: model whose call returns ``(centerlines, points)`` and which
        exposes ``n_classes``.
    :param full_img: source image; ``full_img.size[1]`` is used as the resize
        argument (presumably a PIL image -- confirm against callers).
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: values strictly above this become ``True``.
    :return: tuple of two boolean ndarrays
        ``(centerlines > out_threshold, points > out_threshold)``.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        centerlines, points = net(batch)
        heads = (centerlines, points)

        # Both heads share the same activation choice.
        if net.n_classes > 1:
            activated = [F.softmax(head, dim=1) for head in heads]
        else:
            activated = [torch.sigmoid(head) for head in heads]
        activated = [head.squeeze(0) for head in activated]

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor(),
        ])

        full_centerlines, full_points = (
            restore_size(head.cpu()).squeeze().cpu().numpy()
            for head in activated
        )

    return full_centerlines > out_threshold, full_points > out_threshold
def predict_img(net, full_img, device, scale_factor=1, out_threshold=0.0):
    """
    Predict an argmax map for one image, printing intermediate values.

    :param net: model called as ``net(batch)`` and exposing ``n_classes``.
    :param full_img: source image; ``full_img.size[1]`` is used as the resize
        argument (presumably a PIL image -- confirm against callers).
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: only printed; it does not affect the result.
    :return: ``np.argmax`` over axis 0 of the resized activation array.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    print(batch.shape)
    batch = batch.unsqueeze(0)
    print(batch.shape)
    batch = batch.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)
        print(logits)

        if net.n_classes > 1:
            activated = F.softmax(logits, dim=1)
            print(activated.shape)
        else:
            activated = torch.sigmoid(logits)
        activated = activated.squeeze(0)

        print("probsss", activated)

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor(),
        ])
        resized = restore_size(activated.cpu())

        full_mask = resized.squeeze().cpu().numpy()
        print("*******", full_mask)
        print(out_threshold)

    return np.argmax(full_mask, axis=0)
Пример #21
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """
    Return the sigmoid output for one image and display an argmax preview.

    The preview shows the argmax over axis 0 of the channels from index 2
    onwards and blocks in ``plt.show()``.

    :param net: model called as ``net(batch)``.
    :param full_img: input image, forwarded to ``BasicDataset.preprocess``.
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: accepted for interface compatibility; not used, the
        result is returned without thresholding.
    :return: ndarray with the squeezed sigmoid output.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        # Sigmoid is applied regardless of the number of classes.
        activated = torch.sigmoid(net(batch)).squeeze(0)
        full_mask = activated.squeeze().cpu().numpy()

        import matplotlib.pyplot as plt
        plt.figure()
        plt.imshow(full_mask[2:, :, :].argmax(0))
        plt.colorbar()
        plt.show()

    return full_mask
Пример #22
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    """
    Predict a boolean mask for one image with scalar or per-channel thresholds.

    :param net: model called as ``net(batch)`` and exposing ``n_classes``.
    :param full_img: source image; ``full_img.size[1]`` is used as the resize
        argument (presumably a PIL image -- confirm against callers).
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: either a single number applied to the whole mask,
        or a sequence whose i-th entry thresholds ``full_mask[i]``. Entries
        of the mask without a matching threshold stay ``False``.
    :return: boolean ndarray with the shape of the resized prediction.
    """
    net.eval()

    img = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))

    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        tf = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize(full_img.size[1]),
                transforms.ToTensor()
            ]
        )

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    # A bare number (including the default) cannot be enumerated; apply it
    # to the whole mask at once.
    if np.ndim(out_threshold) == 0:
        return full_mask > out_threshold

    # ``np.bool`` no longer exists in current NumPy; the builtin ``bool`` is
    # the equivalent dtype.
    result = np.zeros(full_mask.shape, dtype=bool)
    for i, thres in enumerate(out_threshold):
        result[i] = full_mask[i] > thres
    return result
Пример #23
0
def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5,
                outf=None):
    """
    Predict a thresholded mask for one image resized to 128x128 on input.

    When ``outf`` is given, the network input tensor is also written to
    ``./data/predicted/train128_<name>`` where ``<name>`` is the fourth
    '/'-separated component of ``outf``.

    :param net: model called as ``net(batch)`` and exposing ``n_classes``.
    :param full_img: input image, forwarded to ``BasicDataset.preprocess``
        together with a 128x128 resize transform.
    :param device: torch device the input batch is moved to.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: values strictly above this become ``True``.
    :param outf: output path of the prediction, or ``None`` to skip saving
        the input tensor. Must contain at least four '/'-separated parts.
    :return: boolean ndarray ``full_mask > out_threshold``.
    :raises IndexError: if ``outf`` has fewer than four '/'-separated parts.
    """
    net.eval()
    transform = transforms.Compose([transforms.Resize((128, 128))])
    img = torch.from_numpy(
        BasicDataset.preprocess(full_img, scale_factor, transform))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img)

        if net.n_classes > 1:
            probs = F.softmax(output, dim=1)
        else:
            probs = torch.sigmoid(output)

        probs = probs.squeeze(0)

        # The prediction is round-tripped through PIL without resizing.
        tf = transforms.Compose([
            transforms.ToPILImage(),
            transforms.ToTensor()
        ])

        probs = tf(probs.cpu())
        full_mask = probs.squeeze().cpu().numpy()

    # The default ``outf=None`` used to crash on ``.split``; saving the
    # network input is now simply skipped when no path is supplied.
    if outf is not None:
        outfn = outf.split('/')
        outtrain_fn = "./data/predicted/train128_" + outfn[3]
        save_image(img, outtrain_fn)

    return full_mask > out_threshold
Пример #24
0
def predict_img(net,
                full_img,
                device,
                file_name,
                scale_factor=1,
                out_threshold=0.5):
    """
    Predict a thresholded mask and dump the probabilities to a ``.mat`` file.

    The activated output is reshaped to 300x300 and saved under the key
    ``'array'`` to a hard-coded results location; the file name is
    ``file_name`` without its last four characters plus ``_prob.mat``.

    :param net: model whose call returns a mapping with key ``'out'`` and
        which exposes ``num_classes``.
    :param full_img: source image; ``full_img.size[1]`` is used as the resize
        argument (presumably a PIL image -- confirm against callers).
    :param device: torch device the input batch is moved to.
    :param file_name: name used to derive the ``.mat`` file name.
    :param scale_factor: scale forwarded to ``BasicDataset.preprocess``.
    :param out_threshold: values strictly above this become ``True``.
    :return: boolean ndarray ``full_mask > out_threshold``.
    """
    net.eval()

    batch = torch.from_numpy(BasicDataset.preprocess(full_img, scale_factor))
    batch = batch.unsqueeze(0).to(device=device, dtype=torch.float32)

    with torch.no_grad():
        logits = net(batch)['out']

        activated = (F.softmax(logits, dim=1)
                     if net.num_classes > 1 else torch.sigmoid(logits))
        activated = activated.squeeze(0)

        # Persist the probability map before any resizing takes place.
        mat_prob = np.reshape(activated.cpu().numpy(), [300, 300])
        results_prefix = 'D:/users/otis/MedicalImage_Project02_Segmentation/MedicalImage_Project02_Segmentation/private_data_10/private_data_10/Results'
        stem = file_name[:-4]
        save_fn = results_prefix + stem + '_prob.mat'
        sio.savemat(save_fn, {'array': mat_prob})

        restore_size = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(full_img.size[1]),
            transforms.ToTensor(),
        ])
        resized = restore_size(activated.cpu())
        full_mask = resized.squeeze().cpu().numpy()

    return full_mask > out_threshold
Пример #25
0
def infer(args, unlabeled, ckpt_file):
    """
    Run the checkpointed model over the unlabeled samples and collect outputs.

    Relies on the module-level ``img_scale``, ``batch_size`` and ``device``.

    :param args: mapping providing ``"TRAINIMAGEDATA_DIR"``,
        ``"TRAINLABEL_DIRECTORY"`` and ``"EXPT_DIR"``.
    :param unlabeled: indices into the training dataset selecting the samples
        to run inference on.
    :param ckpt_file: checkpoint file name, appended to ``args["EXPT_DIR"]``.
    :return: ``{"outputs": predictions}`` where ``predictions`` maps a running
        sample counter (0, 1, ...) to that sample's network output as an
        ndarray.
    """
    # Load the last best model
    traindataset = BasicDataset(
        args["TRAINIMAGEDATA_DIR"], args["TRAINLABEL_DIRECTORY"], img_scale
    )
    unlableddataset = Subset(traindataset, unlabeled)
    # No shuffling: predictions are keyed by a running counter, which is only
    # meaningful if samples are visited in the order given by ``unlabeled``.
    unlabeled_loader = DataLoader(
        unlableddataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
        pin_memory=True,
    )
    predix = 0
    predictions = {}
    net = UNet(n_channels=3, n_classes=1, bilinear=True)
    net.to(device=device)
    net.load_state_dict(torch.load(os.path.join(args["EXPT_DIR"] + ckpt_file)))
    net.eval()

    # Iterate the loader built above (the previous revision referenced an
    # undefined ``val_loader`` / ``n_val`` and a misspelled ``maskpred``).
    with tqdm(
        total=len(unlabeled_loader),
        desc="Validation round",
        unit="batch",
        leave=False,
    ) as pbar:
        for batch in unlabeled_loader:
            # Only the images are needed for inference; masks are ignored.
            imgs = batch["image"].to(device=device, dtype=torch.float32)

            with torch.no_grad():
                mask_pred = net(imgs)

            for logit in mask_pred:
                predictions[predix] = logit.cpu().numpy()
                predix += 1

            pbar.update()

    return {"outputs": predictions}
Пример #26
0
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.001,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):
    """
    Train ``net`` on the dataset under ``dir_img``/``dir_mask``.

    The dataset is split randomly into training and validation parts. The
    network is optimised with RMSprop, using cross entropy for multi-class
    networks and BCE-with-logits otherwise. Roughly ten times per epoch the
    weights and gradients are logged as histograms, the network is validated
    with ``eval_net`` and the plateau scheduler is stepped. Progress is
    written to a TensorBoard ``SummaryWriter``.

    :param net: model exposing ``n_channels`` and ``n_classes``.
    :param device: torch device used for training.
    :param epochs: number of passes over the training part.
    :param batch_size: batch size of both loaders.
    :param lr: initial learning rate.
    :param val_percent: fraction of the dataset held out for validation.
    :param save_cp: when true, a checkpoint is written to ``dir_checkpoint``
        after every epoch.
    :param img_scale: scale passed to ``BasicDataset``.
    :return: None
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    writer = SummaryWriter(
        comment=f'LR_{lr}_BS_{batch_size}_SCALE_{img_scale}')
    global_step = 0

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    optimizer = optim.RMSprop(net.parameters(),
                              lr=lr,
                              weight_decay=1e-8,
                              momentum=0.9)
    # Cross entropy should go down ('min'); the Dice score should go up.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)
    if net.n_classes > 1:
        criterion = nn.CrossEntropyLoss()
    else:
        criterion = nn.BCEWithLogitsLoss()

    # Loop invariants. The interval is clamped to at least one step: with
    # fewer than 10 * batch_size training samples the integer division
    # yields 0 and the modulo below would raise ZeroDivisionError.
    mask_type = torch.float32 if net.n_classes == 1 else torch.long
    val_interval = max(1, n_train // (10 * batch_size))

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                writer.add_scalar('Loss/train', loss.item(), global_step)

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1
                if global_step % val_interval == 0:
                    for tag, value in net.named_parameters():
                        tag = tag.replace('.', '/')
                        writer.add_histogram('weights/' + tag,
                                             value.data.cpu().numpy(),
                                             global_step)
                        # Parameters that received no gradient (e.g. frozen
                        # layers) have ``grad is None``; skip them.
                        if value.grad is not None:
                            writer.add_histogram(
                                'grads/' + tag,
                                value.grad.data.cpu().numpy(),
                                global_step)
                    val_score = eval_net(net, val_loader, device)
                    scheduler.step(val_score)
                    writer.add_scalar('learning_rate',
                                      optimizer.param_groups[0]['lr'],
                                      global_step)

                    if net.n_classes > 1:
                        logging.info(
                            'Validation cross entropy: {}'.format(val_score))
                        writer.add_scalar('Loss/test', val_score, global_step)
                    else:
                        logging.info(
                            'Validation Dice Coeff: {}'.format(val_score))
                        writer.add_scalar('Dice/test', val_score, global_step)

                    writer.add_images('images', imgs, global_step)
                    if net.n_classes == 1:
                        writer.add_images('masks/true', true_masks,
                                          global_step)
                        writer.add_images('masks/pred',
                                          torch.sigmoid(masks_pred) > 0.5,
                                          global_step)

        if save_cp:
            # Best effort: the directory usually exists after the first epoch.
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')

    writer.close()
Пример #27
0
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5,
              data_augment=True):
    """
    Train a single-class ``net`` on the dataset under ``dir_img``/``dir_mask``.

    The dataset is split randomly into training and validation parts. The
    network is optimised with Adam and BCE-with-logits. Roughly ten times per
    epoch it is validated with ``eval_net``; after each epoch the weights are
    saved to ``BEST.pth`` if the latest validation score beats the best so
    far, and optionally to a per-epoch checkpoint.

    :param net: model exposing ``n_channels`` and ``n_classes``.
    :param device: torch device used for training.
    :param epochs: number of passes over the training part.
    :param batch_size: batch size of both loaders.
    :param lr: learning rate.
    :param val_percent: fraction of the dataset held out for validation.
    :param save_cp: when true, a checkpoint is written to ``dir_checkpoint``
        after every epoch.
    :param img_scale: scale passed to ``BasicDataset``.
    :param data_augment: when true, every image/mask pair of a batch is passed
        through ``my_segmentation_transforms``.
    :return: None
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=0,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0,
                            pin_memory=True)

    global_step = 0

    logging.info(f'''Starting training:
    Epochs:          {epochs}
    Batch size:      {batch_size}
    Learning rate:   {lr}
    Training size:   {n_train}
    Validation size: {n_val}
    Checkpoints:     {save_cp}
    Device:          {device.type}
    Images scaling:  {img_scale}
    ''')

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.BCEWithLogitsLoss()  # 1 class
    best_score = 0.
    # No validation has run yet; the best-score check below must not read an
    # unbound name if an epoch finishes before the first validation.
    val_score = None
    # Clamp to at least one step: for datasets smaller than 10 * batch_size
    # the integer division yields 0 and the modulo would raise.
    val_interval = max(1, len(dataset) // (10 * batch_size))

    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        with tqdm(total=n_train,
                  desc=f'Epoch {epoch + 1}/{epochs}',
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                assert true_masks.shape[1] == net.n_classes, \
                    f'Network has been defined with {net.n_classes} output classes, ' \
                    f'but loaded masks have {true_masks.shape[1]} channels. Please check that ' \
                    'the masks are loaded correctly.'

                if data_augment:
                    # Augment each sample of the batch in place.
                    for i in range(len(imgs)):
                        imgs[i], true_masks[i] = my_segmentation_transforms(
                            imgs[i], true_masks[i])

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=torch.float32)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()

                pbar.set_postfix(**{'loss (batch)': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                pbar.update(imgs.shape[0])
                global_step += 1

                if global_step % val_interval == 0:
                    val_score = eval_net(net, val_loader, device, n_val)
                    logging.info('Validation Dice Coeff: {}'.format(val_score))
                    print(" ")
                    print('Validation Dice Coeff: {}'.format(val_score))

        # Compare against the most recent validation score, if there is one.
        if val_score is not None and best_score < val_score:
            torch.save(net.state_dict(), 'BEST.pth')
            logging.info('Best saved !')
            best_score = val_score

        if save_cp:
            # Best effort: the directory usually exists after the first epoch.
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                pass
            torch.save(net.state_dict(),
                       dir_checkpoint + f'CP_epoch{epoch + 1}.pth')
            logging.info(f'Checkpoint {epoch + 1} saved !')
Пример #28
0
def main():
    """Run one training pass of the U-Net lane-segmentation model.

    Builds the augmented ``LaneDataset`` loader, instantiates ``unet_base``
    with an Adam optimizer, then iterates once over the training batches,
    minimising the sum of a BCE-with-logits loss and a multi-class Dice loss.

    The model and every batch are placed on the same device (CUDA when it is
    available, CPU otherwise), so the function no longer fails on CPU-only
    hosts where the previous unconditional ``model.cuda()`` raised.
    """
    data_dir = ''
    val_percent = .1

    # NOTE(review): `epochs` is never read -- the loop below makes a single
    # pass over the data. Kept so the intended schedule stays visible.
    epochs = 9

    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')

    # Worker processes / pinned memory are only requested when a GPU is used.
    kwargs = {
        'num_workers': 4,
        'pin_memory': True
    } if use_cuda else {}
    training_dataset = LaneDataset(
        "~/workspace/myDL/CV/week8/Lane_Segmentation_pytorch/data_list/train.csv",
        transform=transforms.Compose(
            [ImageAug(),
             DeformAug(),
             ScaleAug(),
             CutOut(32, 0.5),
             ToTensor()]))

    training_data_batch = DataLoader(training_dataset,
                                     batch_size=2,
                                     shuffle=True,
                                     drop_last=True,
                                     **kwargs)

    # NOTE(review): this second dataset and its train/val loaders are built
    # but never iterated in this function; only `training_data_batch` feeds
    # the loop below.
    dataset = BasicDataset(data_dir,
                           img_size=cfg.IMG_SIZE,
                           crop_offset=cfg.crop_offset)

    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(train,
                              batch_size=cfg.batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=use_cuda)
    val_loader = DataLoader(val,
                            batch_size=cfg.batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=use_cuda)

    model = unet_base(cfg.num_classes, cfg.IMG_SIZE)
    model.to(device)

    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=cfg.base_lr,
                                 betas=(0.9, 0.99))

    bce_criterion = nn.BCEWithLogitsLoss()
    dice_criterion = MulticlassDiceLoss()

    model.train()
    epoch_loss = 0

    dataprocess = tqdm(training_data_batch)
    for batch_item in dataprocess:
        image = batch_item['image'].to(device=device, dtype=torch.float32)
        mask = batch_item['mask'].to(device=device, dtype=torch.float32)

        # NOTE(review): torch.argmax below returns integer indices and cuts
        # the autograd graph, so loss.backward() produces no gradient for the
        # model parameters and optimizer.step() leaves them unchanged. The
        # requires_grad_() on the target is only what keeps backward() from
        # raising. A real fix must feed differentiable predictions to the
        # losses; that needs the mask layout and the MulticlassDiceLoss
        # contract, neither of which is visible here -- confirm and rework.
        mask.requires_grad_()

        masks_pred = model(image)
        masks_pred = torch.argmax(masks_pred, dim=1).to(torch.float32)

        loss = bce_criterion(masks_pred, mask) + dice_criterion(
            masks_pred, mask)
        epoch_loss += loss.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
Пример #29
0
def train_net(net,
              device,
              epochs=5,
              batch_size=1,
              lr=0.1,
              val_percent=0.1,
              save_cp=True,
              img_scale=0.5):
    """Train ``net`` and keep the best and the most recent weights on disk.

    The dataset is built from the module-level ``dir_img`` / ``dir_mask``
    and randomly split into training and validation subsets. After every
    epoch the network is evaluated with ``eval_net``; the weights are written
    to ``log_dir + '/latest.pth'`` and, when the validation score improved,
    also to ``log_dir + '/best.pth'``.

    Args:
        net: Model exposing ``n_channels`` and ``n_classes``.
        device: Torch device the batches are moved to (its ``type`` is logged).
        epochs: Number of passes over the training subset.
        batch_size: Batch size for both loaders.
        lr: Learning rate for RMSprop.
        val_percent: Fraction of the dataset held out for validation.
        save_cp: Only reported in the start-up log; this function writes
            ``best.pth`` / ``latest.pth`` regardless of its value.
        img_scale: Scale factor forwarded to ``BasicDataset``.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=8,
                              pin_memory=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True)

    logging.info(f'''Starting training:
        Epochs:          {epochs}
        Batch size:      {batch_size}
        Learning rate:   {lr}
        Training size:   {n_train}
        Validation size: {n_val}
        Checkpoints:     {save_cp}
        Device:          {device.type}
        Images scaling:  {img_scale}
    ''')

    multiclass = net.n_classes > 1
    optimizer = optim.RMSprop(net.parameters(), lr=lr, weight_decay=1e-8)
    criterion = nn.CrossEntropyLoss() if multiclass else nn.BCEWithLogitsLoss()
    # Loop-invariant: float targets for the single-class loss, integer class
    # targets otherwise.
    mask_type = torch.float32 if net.n_classes == 1 else torch.long

    # The validation score is logged below as a cross entropy for multi-class
    # nets (lower is better) and as a Dice coefficient otherwise (higher is
    # better), so "best" has to be tracked in the matching direction.
    best_score = float('inf') if multiclass else 0

    for epoch in range(epochs):
        net.train()
        epoch_loss = 0
        start = time.time()
        with tqdm(total=n_train, desc=f'Epoch {epoch}', unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, \
                    f'Network has been defined with {net.n_channels} input channels, ' \
                    f'but loaded images have {imgs.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                imgs = imgs.to(device=device, dtype=torch.float32)
                true_masks = true_masks.to(device=device, dtype=mask_type)

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()

                pbar.set_postfix(**{'loss': loss.item()})

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
                pbar.update(imgs.shape[0])
        cost_time = time.time() - start
        logging.info(f"{epoch} loss: {epoch_loss:.5f} time {cost_time:.3f}s")

        val_score = eval_net(net, val_loader, device, n_val)
        if multiclass:
            logging.info('Validation cross entropy: {:.5f}'.format(val_score))
            improved = val_score < best_score
        else:
            logging.info('Validation Dice Coeff: {:.5f}'.format(val_score))
            improved = val_score > best_score

        if improved:
            torch.save(net.state_dict(), log_dir + '/best.pth')
            best_score = val_score
            logging.info(f'best improved to {val_score:.5f}')
        torch.save(net.state_dict(), log_dir + "/latest.pth")
Пример #30
0
def train_net(net,
              device,
              epochs=100,
              batch_size=1,
              lr=0.1,
              val_percent=0.2,
              save_cp=True,
              img_scale=1):
    """Train ``net`` with gradient accumulation and TensorBoard logging.

    The dataset is built from the module-level ``dir_img`` / ``dir_mask`` and
    randomly split into training and validation subsets. Gradients are
    accumulated over ``mybatch_size`` consecutive batches before each
    optimizer step. After every epoch the weights and gradients are written
    as histograms, the network is evaluated with ``eval_net``, the plateau
    scheduler is stepped on the validation score and, when ``save_cp`` is
    set, the whole model object is saved under ``dir_checkpoint`` if either
    the training loss or the validation score improved.

    Args:
        net: Model exposing ``n_channels`` and ``n_classes``.
        device: Torch device the batches are moved to (its ``type`` is logged).
        epochs: Number of passes over the training subset.
        batch_size: Batch size for both loaders.
        lr: Learning rate for Adam.
        val_percent: Fraction of the dataset held out for validation.
        save_cp: Whether to write checkpoints.
        img_scale: Scale factor forwarded to ``BasicDataset``.
    """
    dataset = BasicDataset(dir_img, dir_mask, img_scale)
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train, val = random_split(dataset, [n_train, n_val])
    train_loader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=16,
                              pin_memory=True,
                              drop_last=True)
    val_loader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=8,
                            pin_memory=True,
                            drop_last=True)

    gene_eval_data(val_loader, dir='./data/val/')

    writer = SummaryWriter(
        comment='LR_{}_BS_{}_SCALE_{}'.format(lr, batch_size, img_scale))
    global_step = 0

    logging.info('''Starting training:
        Epochs:          {}
        Batch size:      {}
        Learning rate:   {}
        Training size:   {}
        Validation size: {}
        Checkpoints:     {}
        Device:          {}
        Images scaling:  {}
    '''.format(epochs, batch_size, lr, n_train, n_val, save_cp, device.type,
               img_scale))

    # The validation score is logged below as a cross entropy for multi-class
    # nets (lower is better) and as a Dice coefficient otherwise (higher is
    # better); the scheduler mode and the checkpoint criterion both follow it.
    multiclass = net.n_classes > 1

    optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        'min' if multiclass else 'max',
        factor=0.5,
        patience=20)

    criterion = dice_loss

    last_loss = 9999
    last_val_score = float('inf') if multiclass else 0
    # Number of batches whose gradients are summed before one optimizer step.
    mybatch_size = 4
    for epoch in range(epochs):
        net.train()

        epoch_loss = 0
        step = 0
        with tqdm(total=n_train,
                  desc='Epoch {}/{}'.format(epoch + 1, epochs),
                  unit='img') as pbar:
            for batch in train_loader:
                imgs = batch['image']
                true_masks = batch['mask']
                assert imgs.shape[1] == net.n_channels, (
                    'Network has been defined with {} input channels, '
                    'but loaded images have {} channels. Please check that '
                    'the images are loaded correctly.'
                ).format(net.n_channels, imgs.shape[1])

                imgs = imgs.to(device=device, dtype=torch.float32)
                mask_type = torch.float32 if net.n_classes == 1 else torch.long
                true_masks = true_masks.to(device=device, dtype=mask_type)

                # Clear gradients at the START of an accumulation window
                # rather than right after the step: the parameter updates are
                # the same, but the last window's gradients are still
                # available for the histogram logging after the epoch.
                if step == 0:
                    optimizer.zero_grad()

                masks_pred = net(imgs)
                loss = criterion(masks_pred, true_masks)
                epoch_loss += loss.item()
                loss.backward()
                nn.utils.clip_grad_value_(net.parameters(), 0.1)
                global_step += 1
                writer.add_scalar('Loss/train', loss.item(), global_step)
                pbar.set_postfix(**{'loss (batch)': loss.item()})
                step += 1
                if step % mybatch_size == 0:
                    optimizer.step()
                    step = 0

                pbar.update(imgs.shape[0])

            # Apply a partially filled accumulation window instead of letting
            # its gradients leak into the first step of the next epoch.
            if step:
                optimizer.step()

        for tag, value in net.named_parameters():
            tag = tag.replace('.', '/')
            writer.add_histogram('weights/' + tag,
                                 value.data.cpu().numpy(), global_step)
            # .grad is None for parameters that never received a gradient.
            if value.grad is not None:
                writer.add_histogram('grads/' + tag,
                                     value.grad.data.cpu().numpy(),
                                     global_step)
        val_score = eval_net(net, val_loader, device)
        scheduler.step(val_score)
        writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'],
                          global_step)

        if multiclass:
            logging.info('Validation cross entropy: {}'.format(val_score))
            writer.add_scalar('Loss/test', val_score, global_step)
        else:
            # The reported train loss is the sum of per-batch losses divided
            # by the number of training samples.
            logging.info('Train Loss: {}    Validation Dice Coeff: {} '.format(
                epoch_loss / n_train, val_score))
            writer.add_scalar('Dice/test', val_score, global_step)

            # Images/masks come from the last training batch of the epoch.
            writer.add_images('images', imgs, global_step)
            if net.n_classes == 1:
                writer.add_images('masks/true', true_masks, global_step)
                writer.add_images('masks/pred',
                                  torch.sigmoid(masks_pred) > 0.3, global_step)

        if save_cp:
            try:
                os.mkdir(dir_checkpoint)
                logging.info('Created checkpoint directory')
            except OSError:
                # Best effort: the directory usually exists already; any
                # other problem surfaces when torch.save writes below.
                pass

            if multiclass:
                val_improved = val_score < last_val_score
            else:
                val_improved = last_val_score < val_score

            if last_loss > epoch_loss or val_improved:
                last_loss = min(last_loss, epoch_loss)
                if multiclass:
                    last_val_score = min(last_val_score, val_score)
                else:
                    last_val_score = max(last_val_score, val_score)
                cp_name = 'CP_epoch{}Trainloss{}ValDice{}.pt'.format(
                    epoch + 1, epoch_loss / n_train, val_score)
                # Saves the whole model object, not just its state_dict.
                torch.save(net, dir_checkpoint + cp_name)
                logging.info('Checkpoint {} saved !'.format(epoch + 1) +
                             '   ' + cp_name)

    writer.close()