def val(net, dataloader_):
    """Run one validation pass, log the metrics and checkpoint on a new best IoU.

    The network is put in eval mode for the pass and switched back to train
    mode afterwards. The per-epoch mean loss and mean IoU are plotted on
    ``board_loss``. When the mean IoU exceeds the module-level
    ``highest_iou``, that global is updated and the model's ``state_dict``
    is saved under ``checkpoint/``.

    Relies on module-level names defined elsewhere: ``criterion``,
    ``iou_loss``, ``pred2segmentation``, ``class_number``, ``use_cuda``,
    ``val_print_frequncy``, ``board_val_image``, ``board_loss``,
    ``showImages`` and ``Equalize``.

    Args:
        net: The model to evaluate; called as ``net(img)``.
        dataloader_: Iterable yielding ``(img, mask, _)`` triples.
    """
    import os  # local import: only needed to create the checkpoint directory

    global highest_iou
    net.eval()
    iou_meter_val = AverageValueMeter()
    loss_meter_val = AverageValueMeter()
    iou_meter_val.reset()
    on_gpu = torch.cuda.is_available() and use_cuda

    # Validation never back-propagates, so disable autograd to avoid
    # building (and holding memory for) a computation graph per batch.
    with torch.no_grad():
        for i, (img, mask, _) in tqdm(enumerate(dataloader_)):
            if on_gpu:
                img, mask = img.cuda(), mask.cuda()
            pred_val = net(img)
            loss_val = criterion(pred_val, mask.squeeze(1))
            loss_meter_val.add(loss_val.item())

            # Computed once and reused for both the metric and the preview.
            segmentation = pred2segmentation(pred_val)
            # Element [1] of the iou_loss result is the value being tracked.
            iou_val = iou_loss(segmentation,
                               mask.squeeze(1).float(), class_number)[1]
            iou_meter_val.add(iou_val)
            if i % val_print_frequncy == 0:
                showImages(board_val_image, img, mask, segmentation)

    mean_iou = iou_meter_val.value()[0]
    board_loss.plot('val_iou_per_epoch', mean_iou)
    board_loss.plot('val_loss_per_epoch', loss_meter_val.value()[0])
    net.train()

    if highest_iou < mean_iou:
        highest_iou = mean_iou
        # torch.save does not create missing parent directories.
        os.makedirs('checkpoint', exist_ok=True)
        torch.save(
            net.state_dict(), 'checkpoint/modified_ENet_%.3f_%s.pth' %
            (mean_iou, 'equal_' + str(Equalize)))
        print('The highest IOU is:%.3f' % mean_iou, 'Model saved.')
def train():
    """Train the module-level ``net`` for ``max_epoch`` epochs.

    Each epoch: optionally rescale the learning rate, run one pass over
    ``train_loader`` with a forward/backward/step per batch, plot the
    per-epoch mean IoU and loss on ``board_loss``, then call
    ``val(net, val_loader)``.

    Relies on module-level names defined elsewhere: ``net``, ``optimiser``,
    ``criterion``, ``train_loader``, ``val_loader``, ``max_epoch``,
    ``use_cuda``, ``class_number``, ``train_print_frequncy``,
    ``board_train_image``, ``board_loss``, ``iou_loss``,
    ``pred2segmentation`` and ``showImages``.
    """
    net.train()
    iou_meter = AverageValueMeter()
    loss_meter = AverageValueMeter()
    on_gpu = torch.cuda.is_available() and use_cuda

    for epoch in range(max_epoch):
        iou_meter.reset()
        loss_meter.reset()
        if epoch % 5 == 0:
            # Every 5th epoch the *current* lr is multiplied by
            # 0.95 ** (epoch // 10), so the decay compounds over time.
            # NOTE(review): confirm the compounding schedule is intended.
            for param_group in optimiser.param_groups:
                param_group['lr'] = param_group['lr'] * (0.95**(epoch // 10))
                print('learning rate:', param_group['lr'])

        for i, (img, mask, _) in tqdm(enumerate(train_loader)):
            if on_gpu:
                img, mask = img.cuda(), mask.cuda()
            optimiser.zero_grad()
            pred = net(img)
            loss = criterion(pred, mask.squeeze(1))
            loss.backward()
            optimiser.step()

            # Record the batch loss exactly once; the original added it
            # twice, double-counting every sample in the meter.
            loss_meter.add(loss.item())

            # Computed once and reused for both the metric and the preview.
            segmentation = pred2segmentation(pred)
            # Element [1] of the iou_loss result is the value being tracked.
            iou = iou_loss(segmentation,
                           mask.squeeze(1).float(), class_number)[1]
            iou_meter.add(iou)

            if i % train_print_frequncy == 0:
                showImages(board_train_image, img, mask, segmentation)

        board_loss.plot('train_iou_per_epoch', iou_meter.value()[0])
        board_loss.plot('train_loss_per_epoch', loss_meter.value()[0])

        val(net, val_loader)
def evaluate(args, net):
    """Segment one image, dilate the mask and save it as a JPEG.

    The image at ``args.img_dir / args.input_name`` is loaded through
    ``image_transformation``, passed through ``net`` as a batch of one,
    converted to a segmentation with ``pred2segmentation``, dilated with
    ``dilate_segmentation`` and resized to the ``(l, h)`` size returned by
    ``image_transformation``. The result is written to the directory
    ``<args.out_dir>__kernel__<args.kernel_size>`` under the input's base
    name.

    Args:
        args: Namespace providing ``img_dir``, ``input_name``, ``device``,
            ``kernel_size`` and ``out_dir``.
        net: The model; called as ``net(batch)``.
    """
    # Only the forward pass needs autograd disabled.
    with torch.no_grad():
        img, (l, h) = image_transformation(
            os.path.join(args.img_dir, args.input_name))
        prediction = net(img.unsqueeze(0).to(args.device))
        segmentation = pred2segmentation(
            prediction).cpu().data.numpy().squeeze(0)

    seg_dil = dilate_segmentation(segmentation, args.kernel_size)
    seg_dil = Image.fromarray(np.uint8(seg_dil * 255)).resize((l, h))

    # Build the output directory name once; exist_ok avoids the
    # check-then-create race of os.path.exists + os.makedirs.
    out_dir = args.out_dir + '__kernel__' + str(args.kernel_size)
    os.makedirs(out_dir, exist_ok=True)

    output_name = os.path.basename(args.input_name)
    seg_dil.save(os.path.join(out_dir, output_name),
                 'JPEG',
                 optimize=True,
                 progressive=True)
# Example no. 4
# 0
    cudnn.benchmark = True

# Dataset and loader for the evaluation run: no augmentation, no shuffling.
valdata = ISICdata(
    root=root,
    model='train',
    transform=True,
    dataAugment=False,
    equalize=Equalize,
)
val_loader = DataLoader(
    valdata,
    batch_size=batch_size,
    shuffle=False,
    num_workers=number_workers,
    pin_memory=True,
)

# One running average for the raw network output, one for the CRF output.
iou_meter_val = AverageValueMeter()
iou_meter_val.reset()
iou_crf_meter_val = AverageValueMeter()
iou_crf_meter_val.reset()

net.eval()
plt.ion()
with torch.no_grad():
    for i, (img, mask, (img_path, mask_path)) in tqdm(enumerate(val_loader)):
        if torch.cuda.is_available() and use_cuda:
            img, mask = img.cuda(), mask.cuda()

        # Re-open the source image for the first sample of the batch.
        # NOTE(review): only index 0 of each batch is used for the CRF
        # branch -- presumably batch_size is 1 here; confirm.
        image_file = os.path.join(valdata.root,
                                  'ISIC2018_Task1-2_Training_Input',
                                  img_path[0])
        orginal_img = Image.open(image_file).resize((384, 384))

        pred_val = F.softmax(net(img), dim=1)

        # CRF inputs: the uint8 image array and the channel-1 probability
        # map of the first sample as float32.
        rgb_array = np.array(orginal_img).astype(np.uint8)
        channel1_prob = pred_val[0, 1].cpu().data.numpy().astype(np.float32)
        full_prediction = dense_crf(rgb_array, channel1_prob)
        plt.imshow(full_prediction)
        plt.show()

        # Element [1] of each iou_loss result is what gets averaged.
        iou_val = iou_loss(pred2segmentation(pred_val),
                           mask.squeeze(1).float(),
                           class_number)[1]
        iou_crf_val = iou_loss(full_prediction,
                               mask[0, 0].cpu().data.numpy(),
                               class_number)[1]
        iou_meter_val.add(iou_val)
        iou_crf_meter_val.add(iou_crf_val)

# Mean IoU without, then with, CRF post-processing.
print(iou_meter_val.value()[0])
print(iou_crf_meter_val.value()[0])