# Exemplo n.º 1  (scrape-artifact separator — commented out so the file parses)
# 0
from networks.unet import Unet
from networks.dunet import Dunet
from networks.dinknet import LinkNet34, DinkNet34, DinkNet50, DinkNet101, DinkNet34_less_pool
from framework import MyFrame
from loss import dice_bce_loss
from data import ImageFolder

# --- D-LinkNet (DinkNet34) road-extraction training setup ---
SHAPE = (1024, 1024)  # expected input tile size
ROOT = 'dataset/train/'
# Keep only the '*sat*' satellite images from the training directory.
# NOTE(review): in Python 3, filter()/map() return one-shot iterators, so
# `trainlist` is exhausted after a single traversal -- confirm ImageFolder
# materializes it (e.g. with list()) before repeated use.
imagelist = filter(lambda x: x.find('sat') != -1, os.listdir(ROOT))
trainlist = map(lambda x: x[:-8], imagelist)  # strip the trailing '_sat.jpg'
NAME = 'log01_dink34'
BATCHSIZE_PER_CARD = 4

# Model / loss / LR wrapper; total batch size scales with the GPU count.
solver = MyFrame(DinkNet34, dice_bce_loss, 2e-4)
batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

dataset = ImageFolder(trainlist, ROOT)
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batchsize,
                                          shuffle=True,
                                          num_workers=4)

mylog = open('logs/' + NAME + '.log', 'w')  # experiment log file
tic = time()  # wall-clock start for per-epoch timing
no_optim = 0  # epochs without loss improvement
total_epoch = 300
train_epoch_best_loss = 100.  # sentinel "worst" starting loss
# NOTE(review): the epoch-loop body below appears truncated in this snippet;
# it only re-creates the loader iterator.
for epoch in range(1, total_epoch + 1):
    data_loader_iter = iter(data_loader)
def CE_Net_Train():
    """Train CE-Net for 2D medical image segmentation.

    Hyper-parameters come from the project-level ``Constants`` module. Each
    epoch, the last batch's image/label/prediction triplet is pushed to
    Visdom and saved under test_data/results2/, and progress is written
    both to stdout and to logs/<NAME>.log.
    """
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom client (one environment per experiment name)
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    print("count", Constants.BATCHSIZE_PER_CARD)
    # total batch size scales with the number of visible GPUs
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD
    print("batchsize", batchsize)

    # For different 2D medical image segmentation tasks, please specify the dataset which you use
    # for examples: you could specify "dataset = 'DRIVE' " for retinal vessel detection.
    dataset = ImageFolder(root_path=Constants.ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    # start the logging file
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0  # consecutive epochs without training-loss improvement
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # Show the last batch's image, prediction and ground truth on Visdom.
        # Assumes img is normalized to roughly [-1.6, 1.6] -- rescaled to
        # [0, 255] for display (TODO confirm the normalization range).
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        # Also dump the same triplet to disk for offline inspection.
        # BUGFIX: dropped the deprecated `range=None` kwarg (removed in
        # recent torchvision releases; None was its default anyway).
        torchvision.utils.save_image(img[0, :, :, :],
                                     "test_data/results2/image_" + str(epoch) +
                                     ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     scale_each=False,
                                     pad_value=0)
        torchvision.utils.save_image(mask[0, :, :, :],
                                     "test_data/results2/mask_" + str(epoch) +
                                     ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     scale_each=False,
                                     pad_value=0)
        torchvision.utils.save_image(pred[0, :, :, :],
                                     "test_data/results2/pred_" + str(epoch) +
                                     ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     scale_each=False,
                                     pad_value=0)
        print("saving images")

        # average training loss over the epoch
        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        # BUGFIX: the original used `print(mylog, ...)` -- a leftover from
        # Python 2's `print >> mylog` -- which printed the file object's
        # repr to stdout instead of writing to the log; use `file=mylog`.
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', Constants.Image_size, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', Constants.Image_size)

        # NOTE(review): best-loss tracking / early stopping was commented
        # out in this variant, so `no_optim` stays 0 and this LR branch can
        # never fire -- confirm whether that is intended.
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            print("loading the weights")
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
# Exemplo n.º 3  (scrape-artifact separator — commented out so the file parses)
# 0
from networks.dinknet import ResNet34_EdgeNet
from framework import MyFrame
from loss import Regularized_Loss
from data import ImageFolder

# --- DBNet (ResNet34_EdgeNet) training setup with edge supervision ---
SHAPE = (512, 512)  # expected input tile size

sat_dir = '/data/train/sat/'            # satellite images
lab_dir = '/data/train/mask_proposal/'  # proposal label masks
hed_dir = '/data/train/rough_edge/'     # rough edge maps
imagelist = os.listdir(lab_dir)
# Strip the 9-character filename suffix to get the sample id.
# NOTE(review): map() returns a one-shot iterator in Python 3 -- confirm
# ImageFolder materializes `trainlist` before repeated use.
trainlist = map(lambda x: x[:-9], imagelist)
NAME = 'DBNet_0'
BATCHSIZE_PER_CARD = 2
solver = MyFrame(ResNet34_EdgeNet, Regularized_Loss, 2e-4)
# total batch size scales with the number of visible GPUs
batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

dataset = ImageFolder(trainlist, sat_dir, lab_dir, hed_dir)
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batchsize,
                                          shuffle=True,
                                          num_workers=0)

mylog = open('logs/' + NAME + '.log', 'w')  # experiment log file
tic = time()  # wall-clock start for per-epoch timing
no_optim = 0  # epochs without loss improvement
total_epoch = 300
train_epoch_best_loss = 100.  # sentinel "worst" starting loss

# NOTE(review): the training-loop body is missing -- this snippet is truncated.
for epoch in range(1, total_epoch + 1):
# Exemplo n.º 4  (scrape-artifact separator — commented out so the file parses)
# 0
from framework import MyFrame
from loss import dice_bce_loss
from data_ganx4 import ImageFolder
import pdb

if __name__ == '__main__':

    # --- GAN/PSPNet x8 training setup ---
    SHAPE = (1024, 1024)  # expected input tile size
    ROOT = 'experiment/'
    training_root = 'img_x8/'
    trainlist = os.listdir(training_root)
    NAME = 'gan_pspnet_x8'
    BATCHSIZE_PER_CARD = 6

    #solver = MyFrame(DinkNet34, dice_bce_loss, 2e-3)
    solver = MyFrame(PSPNet, dice_bce_loss, 2e-3)

    # total batch size scales with the number of visible GPUs
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

    dataset = ImageFolder(trainlist, ROOT)
    a = dataset[1]  # smoke-test: force one sample to load before training
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        num_workers=0)

    # BUGFIX: the original read open('NAME+'.log', 'w') -- a misplaced quote
    # turned this into attribute access `.log` on the literal string 'NAME+'
    # (AttributeError at runtime). The intent is the experiment-name log file.
    mylog = open(NAME + '.log', 'w')
    tic = time()  # wall-clock start for per-epoch timing
    no_optim = 0  # epochs without loss improvement
    total_epoch = 80
    # NOTE(review): the training loop appears truncated in this snippet.
# Exemplo n.º 5  (scrape-artifact separator — commented out so the file parses)
# 0
def CE_Net_Train():
    """Train CE-Net for 2D medical image segmentation.

    Hyper-parameters come from the project-level ``Constants`` module. Each
    epoch, the last batch's image/label/prediction triplet is pushed to
    Visdom; the model with the best training loss is checkpointed to
    ./weights/<NAME>.th, with early stopping and LR decay driven by the
    stagnation counter `no_optim`.
    """
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom client (one environment per experiment name)
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    # total batch size scales with the number of visible GPUs
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD

    # For different 2D medical image segmentation tasks, please specify the dataset which you use
    # for examples: you could specify "dataset = 'DRIVE' " for retinal vessel detection.
    dataset = ImageFolder(root_path=Constants.ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    # start the logging file
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0  # consecutive epochs without training-loss improvement
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # Show the last batch's image, prediction and ground truth on Visdom.
        # Assumes img is normalized to roughly [-1.6, 1.6] -- rescaled to
        # [0, 255] for display (TODO confirm the normalization range).
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        # average training loss over the epoch
        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        # BUGFIX: the original used `print(mylog, ...)` -- a leftover from
        # Python 2's `print >> mylog` -- which printed the file object's
        # repr to stdout instead of writing to the log; use `file=mylog`.
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', Constants.Image_size, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', Constants.Image_size)

        # track the best epoch loss; checkpoint on improvement
        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('./weights/' + NAME + '.th')
        # stop when the loss has stagnated for too long
        if no_optim > Constants.NUM_EARLY_STOP:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        # Otherwise reload the best weights and update the LR by factor 2.0
        # (presumably a decay -- confirm MyFrame.update_lr semantics).
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
# Exemplo n.º 6  (scrape-artifact separator — commented out so the file parses)
# 0
import torch.utils.data as data

import os
from time import time
from networks.dinknet import DUNet
from framework import MyFrame
from data import ImageFolder

# --- DUNet fine-tuning setup on the tiny sat/lab dataset ---
# NOTE(review): `import torch.utils.data as data` above binds only `data`,
# not `torch`; the bare `torch.` references below need a top-level
# `import torch` -- confirm it exists in the full file.
SHAPE = (256, 256)  # expected input tile size
ROOT = 'E:/shao_xing/tiny_dataset/new0228/tiny_sat_lab/'
# Keep only the '*sat*' images.
# NOTE(review): filter()/map() are one-shot iterators in Python 3 -- confirm
# ImageFolder materializes `trainlist` before repeated use.
imagelist = filter(lambda x: x.find('sat') != -1, os.listdir(ROOT))
trainlist = map(lambda x: x[:-8], imagelist)  # strip the trailing '_sat.jpg'
NAME = 'ratio_16'
BATCHSIZE_PER_CARD = 2

# fine-tuning, hence the small fixed learning rate
solver = MyFrame(DUNet, lr=0.00005)
# solver = MyFrame(Unet, dice_bce_loss, 2e-4)
batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

dataset = ImageFolder(trainlist, ROOT)
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batchsize,
                                          shuffle=True,
                                          num_workers=0)

# append mode: continues the existing fine-tune log across runs
mylog = open('log/' + NAME + '_finetune.log', 'a')
tic = time()  # wall-clock start for per-epoch timing
no_optim = 0  # epochs without loss improvement
total_epoch = 100
train_epoch_best_loss = 100.  # sentinel "worst" starting loss
# Exemplo n.º 7  (scrape-artifact separator — commented out so the file parses)
# 0
def Net3_Train(train_i=0):
    """Train a NestedUNet (UNet++) on cross-validation fold ``train_i``.

    Each epoch trains on the fold's train split, evaluates Dice /
    sensitivity / PPV on the test split, checkpoints the model with the
    best test Dice score to ./weights/<NAME>.th, and logs to
    logs/<NAME>.log.
    """
    NAME = 'fold'+str(train_i+1)+'_1UNet'
    mylog = open('logs/' + NAME + '.log', 'w')
    print(NAME)
    print(NAME, file=mylog, flush=True)

    # BUGFIX: batch_size was originally defined *after* the DataLoaders
    # that need it (which also referenced an undefined `batchsize`).
    batch_size = 4
    txt_train = 'fold'+str(train_i+1)+'_train.csv'
    txt_test = 'fold'+str(train_i+1)+'_test.csv'
    dataset_train = MyDataset(txt_path=txt_train, transform=transforms.ToTensor(), target_transform=transforms.ToTensor())
    dataset_test = MyDataset(txt_path=txt_test, transform=transforms.ToTensor(), target_transform=transforms.ToTensor())
    # BUGFIX: both loaders originally read from an undefined `dataset`;
    # wire the train and test datasets to their respective loaders.
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batch_size, shuffle=False, num_workers=2)

    solver = MyFrame(NestedUNet, dice_bce_loss, 2e-4)  # renamed from misspelled `slover`
    total_epoch = 100
    no_optim = 0  # consecutive epochs without training-loss improvement
    train_epoch_best_loss = 10000  # sentinel "worst" starting loss
    best_test_score = 0
    for epoch in range(1, total_epoch+1):
        data_loader_iter = iter(train_loader)
        data_loader_test = iter(test_loader)
        train_epoch_loss = 0
        index = 0

        tic = time()  # per-epoch timer

        train_score = 0
        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_score += dice_coeff(mask, pred, False)
            train_epoch_loss += train_loss
            index += 1

        # evaluate on the held-out fold
        test_sen = 0
        test_ppv = 0
        test_score = 0
        for img, mask in data_loader_test:
            solver.set_input(img, mask)
            pre_mask, _ = solver.test_batch()
            # BUGFIX: metrics originally compared predictions against an
            # undefined `y_test`; the ground truth for this batch is `mask`.
            test_score += dice_coeff(mask, pre_mask, False)
            test_sen += sensitive(mask, pre_mask)
            test_ppv += positivepv(mask, pre_mask)
        test_sen /= len(data_loader_test)
        test_ppv /= len(data_loader_test)
        test_score /= len(data_loader_test)

        # checkpoint whenever the test Dice score improves
        if test_score > best_test_score:
            print('1. the dice score up to ', test_score, 'from ', best_test_score, 'saving the model', file=mylog, flush=True)
            print('1. the dice score up to ', test_score, 'from ', best_test_score, 'saving the model')
            best_test_score = test_score
            solver.save('./weights/'+NAME+'.th')

        # average the training metrics over the epoch
        train_epoch_loss = train_epoch_loss/len(data_loader_iter)
        train_score = train_score/len(data_loader_iter)
        print('epoch:', epoch, '    time:', int(time() - tic), 'train_loss:', train_epoch_loss.cpu().data.numpy(), 'train_score:', train_score, file=mylog, flush=True)
        print('test_dice_loss: ', test_score, 'test_sen: ', test_sen, 'test_ppv: ', test_ppv, 'best_score is ', best_test_score, file=mylog, flush=True)

        # stagnation counter (NOTE(review): nothing consumes `no_optim`
        # in this function -- early stopping may have been removed).
        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
    print('Finish!', file=mylog, flush=True)
    print('Finish!')
    mylog.close()