Exemplo n.º 1
0
def main():
    """Train a fresh UNet or load saved weights, chosen by ``sys.argv[1]``.

    When the first command-line argument is ``train`` the model is handed to
    ``train(...)``; for any other value the state dict stored at ``PATH`` is
    loaded and the model is printed in eval mode.

    NOTE(review): ``x_train``, ``y_train``, ``x_val``, ``y_val``,
    ``width_out``, ``height_out`` and ``PATH`` are not assigned inside this
    function (the assignments below are commented out), so they have to exist
    at module level -- confirm before running.
    """
    # Fail with a readable message instead of an IndexError on sys.argv[1].
    if len(sys.argv) < 2:
        sys.exit('usage: {} train|<anything else to load and print>'.format(
            sys.argv[0]))

    # width_in = 284
    # height_in = 284
    # width_out = 196
    # height_out = 196
    # PATH = './unet.pt'
    # x_train, y_train, x_val, y_val = get_dataset(width_in, height_in, width_out, height_out)
    # print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)

    batch_size = 3
    epochs = 1
    epoch_lapse = 50
    threshold = 0.5
    learning_rate = 0.01
    unet = UNet(in_channel=1, out_channel=2)
    if use_gpu:
        unet = unet.cuda()
    criterion = torch.nn.CrossEntropyLoss()
    # Use the configured learning rate rather than a second hard-coded copy,
    # so the optimizer and the value passed to train() cannot drift apart.
    optimizer = torch.optim.SGD(unet.parameters(),
                                lr=learning_rate,
                                momentum=0.99)
    if sys.argv[1] == 'train':
        train(unet, batch_size, epochs, epoch_lapse, threshold, learning_rate,
              criterion, optimizer, x_train, y_train, x_val, y_val, width_out,
              height_out)
    else:
        # Remap storages to the CPU only when no GPU is in use.
        state = torch.load(PATH, map_location=None if use_gpu else 'cpu')
        unet.load_state_dict(state)
        print(unet.eval())
Exemplo n.º 2
0
def load_model(data, model_path, cuda=True):
    """Restore a UNet from ``model_path`` and run it on a single sample.

    Args:
        data: input tensor without a batch axis; one is added at position 0.
        model_path: path of the saved state dict.
        cuda: run on the GPU when true, on the CPU otherwise.

    Returns:
        The network output for the one-element batch.

    Raises:
        Exception: if ``cuda`` is requested but no GPU is available.
    """
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    unet = UNet()

    if cuda:
        unet = unet.cuda()
        weights = torch.load(model_path)
    else:
        # Keep every storage on the CPU regardless of where it was saved.
        weights = torch.load(model_path,
                             map_location=lambda storage, loc: storage)
    unet.load_state_dict(weights)

    sample = Variable(data.cuda() if cuda else data)
    batch = torch.unsqueeze(sample, 0)

    prediction = unet(batch)
    return prediction.cuda() if cuda else prediction
def define_G(input_nc,
             output_nc,
             ngf,
             norm='batch',
             use_dropout=False,
             gpu_ids=None):
    """Build the generator network and initialise its weights.

    The generator is currently always ``UNet(n_classes=1)``; ``input_nc``,
    ``output_nc``, ``ngf`` and ``use_dropout`` are accepted for interface
    compatibility but are not used by this implementation.

    Args:
        input_nc: unused.
        output_nc: unused.
        ngf: unused.
        norm: normalisation name passed to ``get_norm_layer``.
        use_dropout: unused.
        gpu_ids: optional sequence of GPU indices. When non-empty the model
            is placed on ``gpu_ids[0]``. When empty or ``None`` the model is
            placed on the default GPU if one is available and otherwise
            stays on the CPU.

    Returns:
        The generator with ``weights_init`` applied.
    """
    # A None default avoids sharing one mutable list between calls.
    gpu_ids = [] if gpu_ids is None else gpu_ids

    # The returned layer is not consumed by the UNet generator; the call is
    # kept so its behaviour for the given ``norm`` value is unchanged.
    get_norm_layer(norm_type=norm)

    netG = UNet(n_classes=1)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available(), \
            'gpu_ids were given but CUDA is not available'
        netG.cuda(gpu_ids[0])
    elif torch.cuda.is_available():
        # Previous behaviour: the model lived on the default GPU even when no
        # ids were passed. Keep that, but do not crash on CPU-only hosts.
        netG.cuda()
    netG.apply(weights_init)
    return netG
Exemplo n.º 4
0
def main():
    """Run the train, val or test phase selected by ``args.phase``.

    A ``UNet(3, 1)`` is restored from ``opt.ckpt_path`` and moved to GPU 0
    together with the soft-dice loss. ``train`` runs ``opt.epoch + 1`` epochs
    and multiplies the learning rate by ``opt.lr_decay`` whenever an epoch's
    loss is higher than the previous one; ``val`` evaluates once; every other
    phase value runs the test routine with batch size 1.
    """
    global args
    net = UNet(3, 1)
    net.load(opt.ckpt_path)
    loss = Loss('soft_dice_loss')
    torch.cuda.set_device(0)
    net = net.cuda()
    loss = loss.cuda()

    def build_loader(phase, batch_size=None):
        # All three phases use a shuffled loader over a NucleiDetector set.
        dataset = NucleiDetector(opt, phase=phase)
        if batch_size is None:
            batch_size = opt.batch_size
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=opt.num_workers,
                          pin_memory=opt.pin_memory)

    if args.phase == 'val':
        # val phase
        val(build_loader('val'), net, loss)
        return

    if args.phase != 'train':
        # test phase
        test(build_loader('test', batch_size=1), net, opt)
        return

    # train
    train_loader = build_loader(args.phase)
    lr = opt.lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lr,
                                 weight_decay=opt.weight_decay)
    previous_loss = None  # no epoch has finished yet
    for epoch in range(opt.epoch + 1):
        now_loss = train(train_loader, net, loss, epoch, optimizer,
                         opt.model_save_freq, opt.model_save_path)
        got_worse = previous_loss is not None and now_loss > previous_loss
        if got_worse:
            lr *= opt.lr_decay
            for group in optimizer.param_groups:
                group['lr'] = lr
            save_lr(net.model_name, opt.lr_save_path, lr)
        previous_loss = now_loss
Exemplo n.º 5
0
def test(args):
    """Segment one image with a trained UNet and save a colour mask.

    Reads ``args.test_image``, ``args.cuda``, ``args.model_state_dict``,
    ``args.mask_json_path`` and ``args.test_save_path``. Each output channel
    is thresholded at 0.5; pixels active in a channel are painted with that
    channel's colour from the JSON colour list. When several channels are
    active for a pixel, the highest channel index wins.
    """
    image = load_test_image(args.test_image)  # 1 c w h
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()
        image = image.cuda()
    print('Loading model param from {}'.format(args.model_state_dict))
    # Remap to the CPU when not using CUDA so GPU-saved weights still load.
    net.load_state_dict(
        torch.load(args.model_state_dict,
                   map_location=None if args.cuda else 'cpu'))
    net.eval()

    print('Predicting for {}...'.format(args.test_image))
    with torch.no_grad():  # inference only, no autograd graph needed
        ys_pred = net(image)  # 1 ch w h

    with open(args.mask_json_path, 'r', encoding='utf-8') as mask:
        print('Reading mask colors list from {}'.format(args.mask_json_path))
        colors = [tuple(c) for c in json.load(mask)]
        print('Mask colors: {}'.format(colors))

    ys_pred = ys_pred.cpu().detach().numpy()[0]  # ch w h
    # Boolean masks replace the removed ``np.int`` cast of 0/1 values.
    active = ys_pred >= 0.5
    image_w, image_h = active.shape[1], active.shape[2]
    out_image = np.zeros((image_w, image_h, 3))

    # Paint channel by channel in ascending order so later channels
    # overwrite earlier ones, as the per-pixel loop did.
    for ch, channel_mask in enumerate(active):
        if channel_mask.any():
            out_image[channel_mask] = colors[ch][:3]

    out_image = out_image.astype(np.uint8)  # w h c
    out_image = out_image.transpose((1, 0, 2))  # h w c
    out_image = Image.fromarray(out_image)
    out_image.save(args.test_save_path)
    print('Segmentation result has been saved to {}'.format(
        args.test_save_path))
# Training setup: hyper-parameters, data split, model, optimiser and losses.
batch_size = 4
lr = 0.001

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hold out 10% of the dataset for validation.
dataset = BasicDataset(dir_img, dir_mask)
n_val = int(len(dataset) * 0.1)
n_train = len(dataset) - n_val
train, val = random_split(dataset, [n_train, n_val])
train_loader = DataLoader(train, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val, batch_size=batch_size, shuffle=False)

# writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{1}')

net = UNet(n_channels=3, n_classes=classes, bilinear=True)
# Use the device selected above instead of an unconditional .cuda(), which
# raises on hosts without a GPU.
net = net.to(device)

optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min' if net.n_classes > 1 else 'max', patience=2)

# NOTE(review): assumes ImgWtLossSoftNLL is an nn.Module (it was previously
# moved with .cuda()) -- confirm it supports .to().
criterion = ImgWtLossSoftNLL(classes, epochs).to(device)
criterion_ = nn.CrossEntropyLoss().to(device)

for epoch in range(epochs):
    epoch_loss = 0.0
    for batch in train_loader:
        imgs = batch['image']
        # print(imgs.size())
        true_masks = batch['mask']
        # print(true_masks.size())
Exemplo n.º 7
0
        train_loader, val_loader = get_train_val_loader(
            opt.root_dir,
            batch_size=opt.batch_size,
            val_ratio=0.15,
            shuffle=True,
            num_workers=4,
            pin_memory=False)

        optimizer = optim.Adam(model.parameters(),
                               lr=opt.learning_rate,
                               weight_decay=opt.weight_decay)
        criterion = nn.BCELoss()
        vis = Visualizer(env=opt.env)

        if opt.is_cuda:
            model.cuda()
            criterion.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)

        run(model, train_loader, val_loader, criterion, vis)
    else:
        if opt.is_cuda:
            model.cuda()
            if opt.n_gpu > 1:
                model = nn.DataParallel(model)
        test_loader = get_test_loader(batch_size=20,
                                      shuffle=True,
                                      num_workers=opt.num_workers,
                                      pin_memory=opt.pin_memory)
        # load the model and run test
Exemplo n.º 8
0
def train():
    """Train the model, evaluating on the dev set after every epoch.

    Configuration comes from the module-level ``args``. With ``args.eval``
    the saved model is evaluated once and the process exits. With
    ``args.load_model`` training resumes from ``args.model_dir`` and the
    previously pickled F1/EM histories. After every epoch the histories are
    re-pickled, the model is saved when dev F1 improves, and the learning
    rate is multiplied by ``args.decay`` every ``args.decay_period`` epochs.
    """
    os.makedirs('train_model/', exist_ok=True)
    os.makedirs('result/', exist_ok=True)

    train_data, dev_data, word2id, id2word, char2id, opts = load_data(
        vars(args))
    model = UNet(opts)

    if args.use_cuda:
        model = model.cuda()

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    def evaluate(batches):
        # Single place for the Evaluate arguments shared by all call sites.
        return model.Evaluate(
            batches,
            args.data_path + 'dev_eval.json',
            answer_file='result/' + args.model_dir.split('/')[-1] +
            '.answers',
            drop_file=args.data_path + 'drop.json',
            dev=args.data_path + 'dev-v2.0.json')

    if args.eval:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        evaluate(dev_batches)
        sys.exit()

    if args.load_model:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        _, F1 = evaluate(dev_batches)
        best_score = F1
        # The score histories are pickles written by an earlier run of this
        # function; do not point model_dir at untrusted files.
        with open(args.model_dir + '_f1_scores.pkl', 'rb') as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + '_em_scores.pkl', 'rb') as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    # Must be a list: filter() yields a one-shot iterator that the optimizer
    # would exhaust, leaving clip_grad_norm_ below with nothing to clip.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)

    lrate = args.lrate

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)
        total_size = len(train_data) // args.batch_size

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts['grad_clipping'])
            optimizer.step()
            # Project-defined hook invoked after every optimiser step.
            model.reset_parameters()

            if i % 100 == 0:
                print(
                    'Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f'
                    % (epoch, i, total_size, model.train_loss.value, lrate,
                       best_score))
                sys.stdout.flush()

        model.eval()
        exact_match_score, F1 = evaluate(dev_batches)
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + '_f1_scores.pkl', 'wb') as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + '_em_scores.pkl', 'wb') as f:
            pkl.dump(em_scores, f)

        if best_score < F1:
            best_score = F1
            print('saving %s ...' % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)
        if epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lrate
Exemplo n.º 9
0
import random
from torch.autograd import Variable
import torch
import SimpleITK as sitk
import nrrd
import numpy as np
from model import UNet
from metrics import dice_score, pixelwise_acc, iou, sen_score, DiceLoss
from loss import CB_loss

# Seed torch's random number generator.
torch.manual_seed(10)

# os.environ['OMP_NUM_THREADS']='1'
# os.environ['CUDA_VISIBLE_DEVICES']='1'

# Single-channel input, single-class output model, moved to the GPU.
unet = UNet(n_channels=1, n_classes=1).cuda()

print(unet)


def normalize(x):
    """Min-max scale ``x`` into the range [0, 1].

    Args:
        x: array-like of numbers.

    Returns:
        A float array of the same shape, ``(x - min) / (max - min)``. When
        all values are equal the result is all zeros, which avoids a
        division by zero.
    """
    x = np.asarray(x)
    lowest = np.min(x)
    span = np.max(x) - lowest
    if span == 0:
        return np.zeros_like(x, dtype=float)
    # Return the scaled array; the earlier version computed it and then
    # returned the untouched input.
    return (x - lowest) / span
Exemplo n.º 10
0
#  Code testing config
#  num_points_fetch = 10
#  train_num_pts = 5
#  n_epochs = 4

# Full-run settings (the commented values above are for quick code tests).
train_num_pts = 4800
num_points_fetch = -1
n_epochs = 40

# Pick model: loss variants 5 and 6 build the UNet with flag=1.
train_on_gpu = True
model = (UNet(n_channels=3, n_classes=4, flag=1)
         if LOSS_NUM in [5, 6] else UNet(n_channels=3, n_classes=4))
model = model.float()
model = model.cuda()
summary(model, (3, 140, 210))

# Print parameter choices
print("Learning rate: {}".format(LR))
print("Augmentation: {}".format(AUG))
print(model_string)

# Pick loss
if LOSS_NUM == 0:
    print("Using Binary cross entropy")
    criterion = nn.BCELoss()
elif LOSS_NUM == 1:
    print("Using Dice loss")
    criterion = dice_pytorch
elif LOSS_NUM == 2:
Exemplo n.º 11
0
    netG.eval()
    p = 0
    f_path = '/n/holyscratch01/wadduwage_lab/uom_bme/dataset_static_2020/20200105_synthBeads_1/tr_data_1sls/'    
    for line in img_dir:
        print(line)
        GT_ = io.imread(f_path + str(line[0:-1]) + '_gt.png')
        modalities = np.zeros((32,128,128))
        for i in range(0,32):
             modalities[i,:,:] = io.imread(f_path + str(line[0:-1]) +'_'+str(i+1) +'.png')  
        depth = modalities.shape[2]
        predicted_im = np.zeros((128,128,1))
        if np.min(np.array(GT_))==np.max(np.array(GT_)):
             print('Yes')
        GT = torch.from_numpy(np.divide(GT_,max_gt))
        img = torch.from_numpy(np.divide(modalities,max_im)[None, :, :]).float()
        netG = netG.cuda()
        input = img.cuda()
        out = netG(input)
        out = out.cpu()
        out_img = out.data[0]
        out_img = np.squeeze(out_img)
        GT = np.squeeze(GT)
        predict_path= 'Predicted_mse/epoch_' + str(epochs) +'/'
        if not os.path.exists(predict_path):
            os.makedirs(predict_path)
        imsave(predict_path + '/' + str(line[0:-1]) + '_pred.png',out_img)
        imsave(predict_path + '/' + str(line[0:-1]) + '_gt.png',(GT))
print('mse=',torch.div(avg_mse,p))
print(avg_mse)
print(avg_psnr)
print(p)
Exemplo n.º 12
0
class Trainer():
    """Train, validate and test a UNet with Dice loss and IOU tracking.

    Checkpoints are dicts with the keys ``epoch``, ``model_state``,
    ``optim_state`` and ``best_valid_acc``, written to
    ``config.save_model_dir`` as ``model.pth`` (latest) and
    ``best_model.pth`` (best validation IOU).
    """

    def __init__(self, config, trainLoader, validLoader):
        """Build the model, optimiser and loss from ``config``.

        Args:
            config: object providing ``save_model_dir``, ``bestModel``,
                ``use_gpu``, ``resume``, ``epochs``, ``train_paitence`` and
                ``tensorboard_path``.
            trainLoader: loader yielding ``(images, targets)`` batches.
            validLoader: loader yielding ``(images, targets)`` batches.
        """
        self.config = config
        self.trainLoader = trainLoader
        self.validLoader = validLoader

        self.numTrain = len(self.trainLoader.dataset)
        self.numValid = len(self.validLoader.dataset)

        self.saveModelDir = str(self.config.save_model_dir) + "/"

        self.bestModel = config.bestModel
        self.useGpu = self.config.use_gpu

        self.net = UNet()

        if self.config.resume:
            print("LOADING SAVED MODEL")
            self.loadCheckpoint()
        else:
            print("INTIALIZING NEW MODEL")

        self.device = torch.device(
            "cuda:0" if torch.cuda.is_available() else "cpu")
        self.net = self.net.to(self.device)

        self.totalEpochs = config.epochs

        self.optimizer = optim.Adam(self.net.parameters(), lr=5e-4)
        self.loss = DiceLoss()

        self.num_params = sum(p.numel() for p in self.net.parameters())

        self.trainPaitence = config.train_paitence

        # Epochs since the last validation improvement (early stopping).
        self.counter = 0
        # On resume the writer is created lazily by _logScalar, so a resumed
        # run can still log without touching tensorboard_path up front.
        self.writer = None

        if not self.config.resume:
            summary(self.net, input_size=(3, 256, 256))
            print('[*] Number of model parameters: {:,}'.format(
                self.num_params))
            self.writer = SummaryWriter(self.config.tensorboard_path + "/")

    def _logScalar(self, tag, value, step):
        """Write one scalar to TensorBoard, creating the writer if needed."""
        if getattr(self, 'writer', None) is None:
            self.writer = SummaryWriter(self.config.tensorboard_path + "/")
        self.writer.add_scalar(tag, value, step)

    def train(self):
        """Run the training loop with early stopping on validation IOU.

        The best checkpoint is saved whenever validation IOU improves.
        Training stops once there have been more than ``trainPaitence``
        consecutive epochs without improvement.
        """
        bestIOU = 0
        # Reset here so the first non-improving epoch cannot hit an
        # attribute that was never set.
        self.counter = 0

        print("\n[*] Train on {} sample pairs, validate on {} trials".format(
            self.numTrain, self.numValid))

        for epoch in range(0, self.totalEpochs):
            print('\nEpoch: {}/{}'.format(epoch + 1, self.totalEpochs))

            self.trainOneEpoch(epoch)

            validationIOU = self.validationTest(epoch)

            print("VALIDATION IOU: ", validationIOU)

            # check for improvement
            if validationIOU > bestIOU:
                print("COUNT RESET !!!")
                bestIOU = validationIOU
                self.counter = 0
                self.saveCheckPoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.net.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        'best_valid_acc': bestIOU,
                    }, True)
            else:
                self.counter += 1

            if self.counter > self.trainPaitence:
                self.saveCheckPoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.net.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        'best_valid_acc': validationIOU,
                    }, False)
                print("[!] No improvement in a while, stopping training...")
                print("BEST VALIDATION IOU: ", bestIOU)

                return None

    def trainOneEpoch(self, epoch):
        """Train for one pass over ``trainLoader`` and log mean loss/IOU."""
        self.net.train()
        train_loss = 0
        total_IOU = 0
        num_batches = 0

        for batch_idx, (images, targets) in enumerate(self.trainLoader):
            images = images.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()

            outputMaps = self.net(images)

            loss = self.loss(outputMaps, targets)

            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()

            current_IOU = calc_IOU(outputMaps, targets)
            total_IOU += current_IOU

            del images
            del targets

            num_batches = batch_idx + 1
            progress_bar(
                batch_idx, len(self.trainLoader),
                'Loss: %.3f | IOU: %.3f' %
                (train_loss / num_batches, current_IOU))

        # Average over the batch count. The earlier ``x/batch_idx+1`` parsed
        # as ``(x/batch_idx)+1`` and divided by zero for a single batch.
        if num_batches:
            self._logScalar('Train/Loss', train_loss / num_batches, epoch)
            self._logScalar('Train/IOU', total_IOU / num_batches, epoch)

    def validationTest(self, epoch):
        """Evaluate on ``validLoader``; log and return the mean IOU."""
        self.net.eval()
        validationLoss = []
        total_IOU = []
        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(self.validLoader):
                images = images.to(self.device)
                targets = targets.to(self.device)

                outputMaps = self.net(images)

                loss = self.loss(outputMaps, targets)

                currentValidationLoss = loss.item()
                validationLoss.append(currentValidationLoss)
                current_IOU = calc_IOU(outputMaps, targets)
                total_IOU.append(current_IOU)

                del images
                del targets

        meanIOU = np.mean(total_IOU)
        meanValidationLoss = np.mean(validationLoss)
        self._logScalar('Validation/Loss', meanValidationLoss, epoch)
        self._logScalar('Validation/IOU', meanIOU, epoch)

        print("VALIDATION LOSS: ", meanValidationLoss)

        return meanIOU

    def test(self, dataLoader):
        """Run the model on ``dataLoader`` and return inputs and outputs.

        Returns:
            ``(total_input_images, total_outputs_maps)``: two lists of NumPy
            arrays, one entry per processed batch.
        """
        self.net.eval()
        testLoss = []
        total_IOU = []

        total_outputs_maps = []
        total_input_images = []

        with torch.no_grad():
            for batch_idx, (images, targets) in enumerate(dataLoader):
                images = images.to(self.device)
                targets = targets.to(self.device)

                outputMaps = self.net(images)

                loss = self.loss(outputMaps, targets)

                testLoss.append(loss.item())
                current_IOU = calc_IOU(outputMaps, targets)

                total_IOU.append(current_IOU)

                total_outputs_maps.append(outputMaps.cpu().detach().numpy())

                # total_input_images.append(transforms.ToPILImage()(images))

                total_input_images.append(images.cpu().detach().numpy())

                del images
                del targets
                # NOTE(review): only the first batch is evaluated; kept as-is
                # because callers may rely on single-batch results --
                # confirm whether this was meant to be temporary.
                break

        meanIOU = np.mean(total_IOU)
        meanLoss = np.mean(testLoss)
        print("TEST IOU: ", meanIOU)
        print("TEST LOSS: ", meanLoss)

        return total_input_images, total_outputs_maps

    def saveCheckPoint(self, state, isBest):
        """Save ``state`` as ``model.pth``; copy to ``best_model.pth`` if best."""
        # Make sure the target directory exists before writing.
        os.makedirs(self.saveModelDir, exist_ok=True)

        filename = "model.pth"
        ckpt_path = os.path.join(self.saveModelDir, filename)
        torch.save(state, ckpt_path)

        if isBest:
            filename = "best_model.pth"
            shutil.copyfile(ckpt_path,
                            os.path.join(self.saveModelDir, filename))

    def loadCheckpoint(self):
        """Restore ``self.net`` from the latest or best saved checkpoint.

        ``best_model.pth`` is used when ``self.bestModel`` is truthy,
        ``model.pth`` otherwise. Without GPU use, tensors are mapped to the
        CPU while loading.
        """
        print("[*] Loading model from {}".format(self.saveModelDir))
        if self.bestModel:
            print("LOADING BEST MODEL")
            filename = "best_model.pth"
        else:
            filename = "model.pth"

        ckpt_path = os.path.join(self.saveModelDir, filename)
        print(ckpt_path)

        if not self.useGpu:
            loaded = torch.load(ckpt_path, map_location='cpu')
        else:
            print("*" * 40 + " LOADING MODEL FROM GPU " + "*" * 40)
            loaded = torch.load(ckpt_path)

        if isinstance(loaded, torch.nn.Module):
            # A whole pickled model rather than a checkpoint dict.
            self.net = loaded
        else:
            # saveCheckPoint writes a dict, so load its weights into the
            # existing network instead of replacing the network by the dict.
            self.ckpt = loaded
            self.net.load_state_dict(loaded['model_state'])

        if self.useGpu:
            self.net.cuda()
Exemplo n.º 13
0
class Trainer(object):
    """Train a single-channel UNet with focal loss and report per-batch F1."""

    def __init__(self):
        """Set the fixed hyper-parameters and build the model."""
        torch.set_num_threads(4)
        self.n_epochs = 10
        self.batch_size = 1
        self.patch_size = 384
        self.is_augment = False
        self.cuda = torch.cuda.is_available()
        self.__build_model()

    def __build_model(self):
        """Create the UNet and move it to the GPU when one is available."""
        self.model = UNet(1, 1, base=16)
        if self.cuda:
            self.model = self.model.cuda()

    def __reshapetensor(self, tensor, itype='image'):
        """Merge the first two axes of ``tensor`` into one.

        ``itype='image'`` expects a 5-D tensor and returns a 4-D one; any
        other value expects a 4-D tensor and returns a 3-D one.
        """
        # reshape() also accepts non-contiguous input, where view() raises.
        if itype == 'image':
            d0, d1, d2, d3, d4 = tensor.size()
            tensor = tensor.reshape(d0 * d1, d2, d3, d4)
        else:
            d0, d1, d2, d3 = tensor.size()
            tensor = tensor.reshape(d0 * d1, d2, d3)

        return tensor

    def __get_optimizer(self, **params):
        """Create the optimiser and LR scheduler.

        Keyword arguments override the default RAdam settings (``lr=1e-2``,
        ``weight_decay=1e-5``); they were previously accepted but ignored.
        """
        opt_params = {
            'params': self.model.parameters(),
            'lr': 1e-2,
            'weight_decay': 1e-5
        }
        opt_params.update(params)
        self.optimizer = RAdam(**opt_params)

        # self.scheduler = None
        # NOTE(review): the scheduler is created but run() never steps it.
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(self.optimizer,
                                                              'max',
                                                              factor=0.5,
                                                              patience=10,
                                                              verbose=True,
                                                              min_lr=1e-5)

    def run(self, trainset, model_dir):
        """Train for ``n_epochs`` and save the final weights.

        Args:
            trainset: iterable yielding ``(images, labels)`` batches.
            model_dir: directory that receives ``model.pth``; created
                (including parents) when missing.

        Returns:
            Path of the saved state dict.
        """
        print('=' * 100)
        print('Trainning model')
        print('=' * 100)
        # makedirs handles nested paths and an already-existing directory.
        os.makedirs(model_dir, exist_ok=True)

        model_path = os.path.join(model_dir, 'model.pth')

        #loss_fn = DiceLoss()
        loss_fn = FocalLoss2d()
        #loss_fn = CombineLoss({'dice':0.5, 'focal':0.5})

        self.__get_optimizer()
        Loss = []
        F1 = []
        for epoch in range(self.n_epochs):

            for ith_batch, data in enumerate(trainset):
                images, labels = [d.cuda()
                                  for d in data] if self.cuda else data
                images = self.__reshapetensor(images, itype='image')
                labels = self.__reshapetensor(labels, itype='label')

                preds = self.model(images)
                loss = loss_fn(preds, labels)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                Loss.append(loss.item())
                # Binarise the sigmoid output at 0.5 for the F1 score.
                preds = torch.sigmoid(preds)
                preds[preds > 0.5] = 1
                preds[preds <= 0.5] = 0
                preds = preds.cpu().detach().numpy().flatten()
                labels = labels.cpu().detach().numpy().flatten()
                f1 = f1_score(labels, preds, average='binary')
                F1.append(f1)

                print('EPOCH : {}-----BATCH : {}-----LOSS : {}-----F1 : {}'.
                      format(epoch, ith_batch, loss.item(), f1))

        torch.save(self.model.state_dict(), model_path)

        return model_path
Exemplo n.º 14
0
    conversions = {29: 0, 76: 1, 150: 2, 179: 3, 226: 4, 255: 5}
    gray = cv2.cvtColor(label_im, cv2.COLOR_RGB2GRAY)
    for k in conversions.keys():
        gray[gray == k] = conversions[k]
    # print(np.unique(gray))
    return gray


if __name__ == '__main__':
    # Command line: <model_dir> <weights_file> <image> <label> <patch_size>
    net = UNet(model_dir_path=sys.argv[1], input_channels=3)
    test_model = sys.argv[2]
    image_path = sys.argv[3]
    label_path = sys.argv[4]
    patch = int(sys.argv[5])
    net.load_state_dict(torch.load(test_model))
    net.cuda(device=0)

    image_read = cv2.imread(image_path)
    label_read = cv2.imread(label_path)
    # cv2.imread signals failure by returning None rather than raising, so
    # check here instead of failing later on ``.shape``.
    if image_read is None:
        raise IOError('could not read image: {}'.format(image_path))
    if label_read is None:
        raise IOError('could not read label: {}'.format(label_path))

    small_patch = patch // 4
    full_i = image_read.shape[0] // small_patch
    full_j = image_read.shape[1] // small_patch
    # full_image = np.empty(shape=(small_patch*full_i+small_patch, full_j*small_patch+small_patch, 3))
    # full_label = np.empty(shape=(small_patch*full_i+small_patch, full_j*small_patch+small_patch))
    # full_pred = np.empty(shape=(small_patch*full_i+small_patch, full_j*small_patch+small_patch))
    # Output buffers cover half of the input image in each dimension.
    x, y = image_read.shape[0] // 2, image_read.shape[1] // 2
    full_image = np.empty(shape=(x, y, 3))
    full_label = np.empty(shape=(x, y))
    full_pred = np.empty(shape=(x, y))
    print(image_read.shape)
Exemplo n.º 15
0
def train():
    """Train the model, evaluating on the dev set after every epoch.

    Configuration comes from the module-level ``args``. With ``args.eval``
    the saved model is evaluated once and the process exits. With
    ``args.load_model`` training resumes from ``args.model_dir`` and the
    previously pickled F1/EM histories. After every epoch the histories are
    re-pickled, the model is saved when dev F1 improves, and the learning
    rate is multiplied by ``args.decay`` every ``args.decay_period`` epochs.
    """
    os.makedirs("train_model/", exist_ok=True)
    os.makedirs("result/", exist_ok=True)

    train_data, dev_data, word2id, id2word, char2id, opts = load_data(
        vars(args))
    model = UNet(opts)

    if args.use_cuda:
        model = model.cuda()

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    # One answers path for every evaluation: result/<model name>.answers.
    # Joining ".answers" as a separate path component produced
    # result/<model name>/.answers, inside a directory nobody creates.
    answer_file = os.path.join("result/",
                               args.model_dir.split("/")[-1] + ".answers")

    def evaluate(batches):
        # Single place for the Evaluate arguments shared by all call sites.
        return model.Evaluate(
            batches,
            os.path.join(args.prepro_dir, "dev_eval.json"),
            answer_file=answer_file,
            drop_file=os.path.join(args.prepro_dir, "drop.json"),
            dev=args.dev_file,
        )

    if args.eval:
        print("load model...")
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        evaluate(dev_batches)
        sys.exit()

    if args.load_model:
        print("load model...")
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        _, F1 = evaluate(dev_batches)
        best_score = F1
        # The score histories are pickles written by an earlier run of this
        # function; do not point model_dir at untrusted files.
        with open(args.model_dir + "_f1_scores.pkl", "rb") as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + "_em_scores.pkl", "rb") as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    # Must be a list: filter() yields a one-shot iterator that the optimizer
    # would exhaust, leaving clip_grad_norm_ below with nothing to clip.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)

    lrate = args.lrate

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)
        total_size = len(train_data) // args.batch_size

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts["grad_clipping"])
            optimizer.step()
            # Project-defined hook invoked after every optimiser step.
            model.reset_parameters()

            if i % 100 == 0:
                print(
                    "Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f"
                    % (epoch, i, total_size, model.train_loss.value, lrate,
                       best_score))
                sys.stdout.flush()

        model.eval()
        exact_match_score, F1 = evaluate(dev_batches)
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + "_f1_scores.pkl", "wb") as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + "_em_scores.pkl", "wb") as f:
            pkl.dump(em_scores, f)

        if best_score < F1:
            best_score = F1
            print("saving %s ..." % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)
        if epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group["lr"] = lrate
Exemplo n.º 16
0
def train(args):
    """
    Train UNet from datasets

    Reads ``args.dataset_path``, ``args.batch_size``, ``args.mask_json_path``,
    ``args.cuda``, ``args.lr``, ``args.epochs``, ``args.save_epoch``,
    ``args.checkpoint_path``, ``args.do_save_summary``,
    ``args.summary_image`` and ``args.model_state_dict``. Writes the mask
    colour list, a checkpoint every ``save_epoch`` epochs, an optional loss
    plot and the final state dict.
    """

    # dataset
    print('Reading dataset from {}...'.format(args.dataset_path))
    train_dataset = SSDataset(dataset_path=args.dataset_path, is_train=True)
    val_dataset = SSDataset(dataset_path=args.dataset_path, is_train=False)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

    # mask
    with open(args.mask_json_path, 'w', encoding='utf-8') as mask:
        colors = SSDataset.all_colors
        mask.write(json.dumps(colors))
        print('Mask colors list has been saved in {}'.format(
            args.mask_json_path))

    # model
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()

    # setting
    lr = args.lr  # 1e-3
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = loss_fn

    # run
    train_losses = []
    val_losses = []
    print('Start training...')
    for epoch_idx in range(args.epochs):
        # train
        net.train()
        train_loss = 0.0
        for xs, ys in train_dataloader:
            if args.cuda:
                xs = xs.cuda()
                ys = ys.cuda()
            ys_pred = net(xs)
            loss = criterion(ys_pred, ys)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate a plain float: summing the loss tensors keeps every
            # batch's autograd graph alive for the whole epoch.
            train_loss += loss.item()

        # val
        net.eval()
        val_loss = 0.0
        with torch.no_grad():  # no gradients needed for validation
            for xs, ys in val_dataloader:
                if args.cuda:
                    xs = xs.cuda()
                    ys = ys.cuda()
                val_loss += criterion(net(xs), ys).item()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print('Epoch: {}, Train total loss: {}, Val total loss: {}'.format(
            epoch_idx + 1, train_loss, val_loss))

        # save
        if (epoch_idx + 1) % args.save_epoch == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_path,
                'checkpoint_{}.pth'.format(epoch_idx + 1))
            torch.save(net.state_dict(), checkpoint_path)
            print('Saved Checkpoint at Epoch {} to {}'.format(
                epoch_idx + 1, checkpoint_path))

    # summary
    if args.do_save_summary:
        epoch_range = list(range(1, args.epochs + 1))
        plt.plot(epoch_range, train_losses, 'r', label='Train loss')
        # Plot the per-epoch history, not the last epoch's scalar.
        plt.plot(epoch_range, val_losses, 'g', label='Val loss')
        plt.legend()
        # savefig writes the current figure; imsave expects an image array.
        plt.savefig(args.summary_image)
        print('Summary images have been saved in {}'.format(
            args.summary_image))

    # save
    net.eval()
    torch.save(net.state_dict(), args.model_state_dict)
    print('Saved state_dict in {}'.format(args.model_state_dict))
Exemplo n.º 17
0
from model import UNet, DNet
import data_loader
from data_loader import *

##############################################################
# Build the two networks of the GAN: the generator is a UNet and
# the discriminator is a DNet (both imported from `model`).
generator = UNet(True)
discriminator = DNet()

##################################################################
# Move both networks to the current CUDA device so that the forward
# and backward passes run on the GPU. These calls are unconditional,
# i.e. the script requires a CUDA-capable device.
generator.cuda()
discriminator.cuda()

###################################################################
# Create one Adam optimizer per network (same hyper-parameters for
# both) and the loss criteria: binary cross-entropy for the
# discriminator and for the generator, plus an L1 loss for the
# generator.
d_optimizer = optim.Adam(discriminator.parameters(), betas=(0.5, 0.999), lr=0.0002)
g_optimizer = optim.Adam(generator.parameters(), betas=(0.5, 0.999), lr=0.0002)

d_criterion = nn.BCELoss()
g_criterion_1 = nn.BCELoss()
g_criterion_2 = nn.L1Loss()

# NOTE(review): `train_` is invoked here, before the `def train_()` that
# follows in this file has been executed. Unless the name is supplied by
# the star import from `data_loader`, this raises NameError -- confirm.
train_()

def train_():
Exemplo n.º 18
0
def run_inference(args):
    """Run a pretrained UNet over district images and save forest maps as PNG.

    For every (district, year) pair the matching Landsat-8 GeoTIFF is tiled
    by ``get_inference_loader``, each tile is classified, and the per-tile
    predictions are stitched back into one full-size label map. The map is
    then masked with the loader's adjustment mask and written as an RGB
    image (green = forest, red = non-forest, black = null).

    Args:
        args: Parsed options. The attributes read here are
            ``model_topology``, ``bands``, ``classes``, ``model_path``,
            ``cuda``, ``device``, ``data_path``,
            ``rasterized_shapefiles_path``, ``bs`` and ``dest``.
    """
    model = UNet(topology=args.model_topology,
                 input_channels=len(args.bands),
                 num_classes=len(args.classes))
    # Weights are always loaded onto the CPU first and moved afterwards, so
    # a checkpoint saved on a GPU can be loaded on a CPU-only machine.
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'),
                          strict=False)
    print('Log: Loaded pretrained {}'.format(args.model_path))
    model.eval()
    if args.cuda:
        print('log: Using GPU')
        model.cuda(device=args.device)

    # Full list of districts, kept for reference:
    # ["abbottabad", "battagram", "buner", "chitral", "hangu", "haripur",
    #  "karak", "kohat", "kohistan", "lower_dir", "malakand", "mansehra",
    #  "nowshehra", "shangla", "swat", "tor_ghar", "upper_dir"]
    all_districts = ["abbottabad"]
    # Full list of years: [2014, 2016, 2017, 2018, 2019, 2020]
    years = [2016]

    for district in all_districts:
        for year in years:
            print("(LOG): On District: {} @ Year: {}".format(district, year))
            test_image_path = os.path.join(
                args.data_path,
                'landsat8_{}_region_{}.tif'.format(year, district))
            inference_loader, adjustment_mask = get_inference_loader(
                rasterized_shapefiles_path=args.rasterized_shapefiles_path,
                district=district,
                image_path=test_image_path,
                model_input_size=128,
                bands=args.bands,
                num_classes=len(args.classes),
                batch_size=args.bs,
                num_workers=4)

            # Start every pixel at -1 so that, after the "+ 1" shift below,
            # any pixel no tile wrote to ends up as 0 (null) instead of the
            # uninitialised memory an empty array would contain.
            generated_map = np.full(
                shape=inference_loader.dataset.get_image_size(),
                fill_value=-1.0)

            # Pure inference: do not record an autograd graph.
            with torch.no_grad():
                for idx, data in enumerate(inference_loader):
                    coordinates = data['coordinates'].tolist()
                    test_x = data['input']
                    if args.cuda:
                        test_x = test_x.cuda(device=args.device)
                    _, softmaxed = model(test_x)
                    # One class-index map per sample in the batch.
                    pred_numpy = torch.argmax(softmaxed, dim=1).cpu().numpy()
                    if idx % 5 == 0:
                        print('LOG: on {} of {}'.format(
                            idx, len(inference_loader)))
                    # Paste each predicted tile at its position in the map.
                    for (x, x_, y, y_), tile in zip(coordinates, pred_numpy):
                        generated_map[x:x_, y:y_] = tile

            # adjust the inferred map
            generated_map += 1  # forest pixels: 2, non-forest pixels: 1, null pixels: 0
            generated_map = np.multiply(generated_map, adjustment_mask)

            # save generated map as png image, not numpy array
            forest_map_rband = np.zeros_like(generated_map)
            forest_map_gband = np.zeros_like(generated_map)
            forest_map_bband = np.zeros_like(generated_map)
            forest_map_gband[generated_map == FOREST_LABEL] = 255
            forest_map_rband[generated_map == NON_FOREST_LABEL] = 255
            forest_map_for_visualization = np.dstack(
                [forest_map_rband, forest_map_gband,
                 forest_map_bband]).astype(np.uint8)
            save_this_map_path = os.path.join(
                args.dest, '{}_{}_inferred_map.png'.format(district, year))
            matimg.imsave(save_this_map_path, forest_map_for_visualization)
            print('Saved: {} @ {}'.format(save_this_map_path,
                                          forest_map_for_visualization.shape))
Exemplo n.º 19
0
    # initial_epoch = 150
    if initial_epoch > 0:
        print('resuming by loading epoch %03d' % initial_epoch)
        u_model.load_state_dict(
            torch.load(os.path.join(save_dir,
                                    'model_%03d.pth' % initial_epoch)))

        # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))

    model.eval()
    u_model.train()
    criterion = nn.MSELoss()

    if cuda:
        model = model.cuda()
        u_model = u_model.cuda()

    optimizer = optim.Adam(model.parameters(), lr=args.lr)
    scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90],
                            gamma=0.2)  # learning rates
    for epoch in range(initial_epoch, n_epoch):

        scheduler.step(epoch)  # step to the learning rate in this epoch
        xs = dg.datagenerator(data_dir=args.train_data)
        xs = xs.astype('float32') / 255.0
        xs = torch.from_numpy(xs.transpose(
            (0, 3, 1, 2)))  # tensor of the clean patches, NXCXHXW

        DDataset = DenoisingDataset(xs, sigma)
        batch_y, batch_x = DDataset[:238336]
Exemplo n.º 20
0
Arquivo: train.py Projeto: Onojimi/try
if __name__ == '__main__':
    # Script entry point: build the UNet, train it, and always leave a
    # weights file behind -- 'model_fin.pth' after a normal run or
    # 'interrupt.pth' when training is stopped with Ctrl-C.
    args = get_args()
    net = UNet(input_channels=3, nclasses=1)
    writer = SummaryWriter(log_dir='../../log/sn1', comment='unet')

    if args.gpu:
        # Spread the batch over all visible GPUs when more than one exists.
        if torch.cuda.device_count() > 1:
            net = nn.DataParallel(net)
        net.cuda()

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  gpu=args.gpu,
                  writer=writer,
                  load=args.load)

        torch.save(net.state_dict(), 'model_fin.pth')

    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'interrupt.pth')
        print('saved interrupt')
    finally:
        # Flush and release the event file on every exit path so the last
        # logged scalars are not lost.
        writer.close()
        print('saved interrupt')
Exemplo n.º 21
0
def _tta_probmaps(img, prob_maps, test_transform, model, params):
    """Average `prob_maps` with predictions on 7 flipped/rotated copies of `img`.

    Each augmented prediction is mapped back to the original orientation
    (inverse flip, and a 270-degree rotation for the copies rotated by 90
    degrees) before the eight probability maps are averaged.
    """

    def predict(image):
        batch = test_transform((image, ))[0].unsqueeze(0)
        return get_probmaps(batch, model, params)

    # flips of the original image
    img_hf = img.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
    img_vf = img.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
    img_hvf = img_hf.transpose(Image.FLIP_TOP_BOTTOM)  # both flips

    prob_maps_hf = np.flip(predict(img_hf), 2)
    prob_maps_vf = np.flip(predict(img_vf), 1)
    prob_maps_hvf = np.flip(np.flip(predict(img_hvf), 1), 2)

    # rotation by 90 degrees, then the same three flips
    img_r90 = img.rotate(90, expand=True)
    img_r90_hf = img_r90.transpose(Image.FLIP_LEFT_RIGHT)
    img_r90_vf = img_r90.transpose(Image.FLIP_TOP_BOTTOM)
    img_r90_hvf = img_r90_hf.transpose(Image.FLIP_TOP_BOTTOM)

    prob_maps_r90 = np.rot90(predict(img_r90), k=3, axes=(1, 2))
    prob_maps_r90_hf = np.rot90(np.flip(predict(img_r90_hf), 2),
                                k=3,
                                axes=(1, 2))
    prob_maps_r90_vf = np.rot90(np.flip(predict(img_r90_vf), 1),
                                k=3,
                                axes=(1, 2))
    prob_maps_r90_hvf = np.rot90(np.flip(np.flip(predict(img_r90_hvf), 1), 2),
                                 k=3,
                                 axes=(1, 2))

    return (prob_maps + prob_maps_hf + prob_maps_vf + prob_maps_hvf +
            prob_maps_r90 + prob_maps_r90_hf + prob_maps_r90_vf +
            prob_maps_r90_hvf) / 8


def _majority_class_per_nucleus(nuclei_labeled, class_map, num_classes=4):
    """Return, for each nucleus label, the class id covering most of its pixels.

    `nuclei_labeled` holds connected-component labels (0 = background) and
    `class_map` holds per-pixel class ids in ``range(num_classes)``. The
    result is indexed by nucleus label.

    The bin edges are given explicitly, one unit-wide bin centred on every
    integer. Deriving the number of class bins from the classes that happen
    to be present would merge or shift bins whenever a class is missing
    from the image, and nuclei would then be assigned the wrong class id.
    """
    n_labels = int(nuclei_labeled.max()) + 1
    label_edges = np.arange(n_labels + 1) - 0.5
    class_edges = np.arange(num_classes + 1) - 0.5
    intersection = np.histogram2d(nuclei_labeled.ravel(),
                                  class_map.ravel(),
                                  bins=(label_edges, class_edges))[0]
    return np.argmax(intersection, axis=1)


def main():
    """Run a trained segmentation model on a folder of test images.

    Settings come from ``Params().test``. For every image in ``img_dir``
    the function predicts per-class probability maps (optionally with
    test-time augmentation), post-processes them into labelled tumor,
    lymphocyte and stroma nuclei, and then

    * if ``label_dir`` is set, computes detection and segmentation metrics
      against ``<name>_label.png`` and accumulates averages and a
      confusion matrix, which are printed and written to
      ``<save_dir>/test_result.txt``;
    * if ``save_flag`` is set, writes probability maps, the labelled
      segmentation and coloured visualisations under ``save_dir``.

    Raises:
        NotImplementedError: If the configured model name is neither
            'ResUNet34' nor 'UNet'.
    """
    params = Params()
    img_dir = params.test['img_dir']
    label_dir = params.test['label_dir']
    save_dir = params.test['save_dir']
    os.makedirs(save_dir, exist_ok=True)
    model_path = params.test['model_path']
    save_flag = params.test['save_flag']
    tta = params.test['tta']

    params.save_params('{:s}/test_params.txt'.format(save_dir), test=True)

    # metrics are only computed when ground-truth labels are available
    eval_flag = bool(label_dir)
    if eval_flag:
        test_results = dict()
        # 7 values per meter, in the order of the result header below
        tumor_result = utils.AverageMeter(7)
        lym_result = utils.AverageMeter(7)
        stroma_result = utils.AverageMeter(7)
        all_result = utils.AverageMeter(7)
        conf_matrix = np.zeros((3, 3))

    # data transforms
    test_transform = get_transforms(params.transform['test'])

    model_name = params.model['name']
    if model_name == 'ResUNet34':
        model = ResUNet34(params.model['out_c'],
                          fixed_feature=params.model['fix_params'])
    elif model_name == 'UNet':
        model = UNet(3, params.model['out_c'])
    else:
        raise NotImplementedError()
    # The model is wrapped before loading, so the checkpoint's state dict is
    # loaded through the DataParallel wrapper; it is unwrapped afterwards.
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        os.makedirs(prob_maps_folder, exist_ok=True)
        os.makedirs(seg_folder, exist_ok=True)

    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        if eval_flag:
            label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
            gt = misc.imread(label_path)

        input_tensor = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input_tensor, model, params)
        if tta:
            prob_maps = _tta_probmaps(img, prob_maps, test_transform, model,
                                      params)

        pred = np.argmax(prob_maps, axis=0)  # prediction
        pred_inside = pred.copy()
        pred_inside[pred == 4] = 0  # set contours to background
        pred_nuclei_inside_labeled = measure.label(pred_inside > 0)

        pred_tumor_inside = pred_inside == 1
        pred_lym_inside = pred_inside == 2
        pred_stroma_inside = pred_inside == 3
        pred_3types_inside = pred_tumor_inside + pred_lym_inside * 2 + pred_stroma_inside * 3

        # find the correct class for each segmented nucleus
        classes = _majority_class_per_nucleus(pred_nuclei_inside_labeled,
                                              pred_3types_inside)
        tumor_nuclei_indices = np.flatnonzero(classes == 1)
        lym_nuclei_indices = np.flatnonzero(classes == 2)
        stroma_nuclei_indices = np.flatnonzero(classes == 3)

        # solve the problem of one nucleus assigned with different labels
        pred_tumor_inside = np.isin(pred_nuclei_inside_labeled,
                                    tumor_nuclei_indices)
        pred_lym_inside = np.isin(pred_nuclei_inside_labeled,
                                  lym_nuclei_indices)
        pred_stroma_inside = np.isin(pred_nuclei_inside_labeled,
                                     stroma_nuclei_indices)

        # remove small objects
        min_area = params.post['min_area']
        pred_tumor_inside = morph.remove_small_objects(pred_tumor_inside,
                                                       min_area)
        pred_lym_inside = morph.remove_small_objects(pred_lym_inside,
                                                     min_area)
        pred_stroma_inside = morph.remove_small_objects(
            pred_stroma_inside, min_area)

        # connected component labeling
        pred_tumor_inside_labeled = measure.label(pred_tumor_inside)
        pred_lym_inside_labeled = measure.label(pred_lym_inside)
        pred_stroma_inside_labeled = measure.label(pred_stroma_inside)
        # Merge the three label maps into one whose instance ids encode the
        # type in their remainder modulo 3: tumor -> 0, lym -> 1, stroma -> 2
        # (the same encoding used to split `gt` below).
        pred_all_inside_labeled = pred_tumor_inside_labeled * 3 \
                                  + (pred_lym_inside_labeled * 3 - 2) * (pred_lym_inside_labeled > 0) \
                                  + (pred_stroma_inside_labeled * 3 - 1) * (pred_stroma_inside_labeled > 0)

        # dilation, with one structuring element shared by all four maps
        disk = morph.selem.disk(params.post['radius'])
        pred_tumor_labeled = morph.dilation(pred_tumor_inside_labeled,
                                            selem=disk)
        pred_lym_labeled = morph.dilation(pred_lym_inside_labeled, selem=disk)
        pred_stroma_labeled = morph.dilation(pred_stroma_inside_labeled,
                                             selem=disk)
        pred_all_labeled = morph.dilation(pred_all_inside_labeled, selem=disk)

        if eval_flag:
            print('\tComputing metrics...')
            gt_tumor = (gt % 3 == 0) * gt
            gt_lym = (gt % 3 == 1) * gt
            gt_stroma = (gt % 3 == 2) * gt

            tumor_detect_metrics = utils.accuracy_detection_clas(
                pred_tumor_labeled, gt_tumor, clas_flag=False)
            lym_detect_metrics = utils.accuracy_detection_clas(
                pred_lym_labeled, gt_lym, clas_flag=False)
            stroma_detect_metrics = utils.accuracy_detection_clas(
                pred_stroma_labeled, gt_stroma, clas_flag=False)
            all_detect_metrics = utils.accuracy_detection_clas(
                pred_all_labeled, gt, clas_flag=True)

            tumor_seg_metrics = utils.accuracy_object_level(
                pred_tumor_labeled, gt_tumor, hausdorff_flag=False)
            lym_seg_metrics = utils.accuracy_object_level(
                pred_lym_labeled, gt_lym, hausdorff_flag=False)
            stroma_seg_metrics = utils.accuracy_object_level(
                pred_stroma_labeled, gt_stroma, hausdorff_flag=False)
            all_seg_metrics = utils.accuracy_object_level(
                pred_all_labeled, gt, hausdorff_flag=True)

            # the last detection entry is the confusion matrix, which is
            # accumulated separately from the scalar metrics
            tumor_metrics = [*tumor_detect_metrics[:-1], *tumor_seg_metrics]
            lym_metrics = [*lym_detect_metrics[:-1], *lym_seg_metrics]
            stroma_metrics = [*stroma_detect_metrics[:-1], *stroma_seg_metrics]
            all_metrics = [*all_detect_metrics[:-1], *all_seg_metrics]
            conf_matrix += np.array(all_detect_metrics[-1])

            # save result for each image
            test_results[name] = {
                'tumor': tumor_metrics,
                'lym': lym_metrics,
                'stroma': stroma_metrics,
                'all': all_metrics
            }

            # update the average result
            tumor_result.update(tumor_metrics)
            lym_result.update(lym_metrics)
            stroma_result.update(stroma_metrics)
            all_result.update(all_metrics)

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                        pred.astype(np.uint8) * 50)
            misc.imsave(
                '{:s}/{:s}_prob_tumor.png'.format(prob_maps_folder, name),
                prob_maps[1, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_lym.png'.format(prob_maps_folder, name),
                prob_maps[2, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_stroma.png'.format(prob_maps_folder, name),
                prob_maps[3, :, :])
            final_pred = Image.fromarray(pred_all_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored = np.zeros((ori_h, ori_w, 3))
            pred_colored_instance = np.zeros((ori_h, ori_w, 3))
            pred_colored[pred_tumor_labeled > 0] = np.array([255, 0, 0])
            pred_colored[pred_lym_labeled > 0] = np.array([0, 255, 0])
            pred_colored[pred_stroma_labeled > 0] = np.array([0, 0, 255])
            filename = '{:s}/{:s}_seg_colored_3types.png'.format(
                seg_folder, name)
            misc.imsave(filename, pred_colored)
            for k in range(1, pred_all_labeled.max() + 1):
                pred_colored_instance[pred_all_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored_instance)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    print('=> Processed all {:d} images'.format(counter))
    if eval_flag:
        print(
            'Average: clas_acc\trecall\tprecision\tF1\tdice\tiou\thausdorff\n'
            'tumor: {t[0]:.4f}, {t[1]:.4f}, {t[2]:.4f}, {t[3]:.4f}, {t[4]:.4f}, {t[5]:.4f}, {t[6]:.4f}\n'
            'lym: {l[0]:.4f}, {l[1]:.4f}, {l[2]:.4f}, {l[3]:.4f}, {l[4]:.4f}, {l[5]:.4f}, {l[6]:.4f}\n'
            'stroma: {s[0]:.4f}, {s[1]:.4f}, {s[2]:.4f}, {s[3]:.4f}, {s[4]:.4f}, {s[5]:.4f}, {s[6]:.4f}\n'
            'all: {a[0]:.4f}, {a[1]:.4f}, {a[2]:.4f}, {a[3]:.4f}, {a[4]:.4f}, {a[5]:.4f}, {a[6]:.4f}'
            .format(t=tumor_result.avg,
                    l=lym_result.avg,
                    s=stroma_result.avg,
                    a=all_result.avg))

        header = [
            'clas_acc', 'recall', 'precision', 'F1', 'Dice', 'IoU', 'Hausdorff'
        ]
        save_results(header, tumor_result.avg, lym_result.avg,
                     stroma_result.avg, all_result.avg, test_results,
                     conf_matrix, '{:s}/test_result.txt'.format(save_dir))
Exemplo n.º 22
0
    stage3_boxmodel.load_state_dict(
        torch.load(os.path.join(root_dir, stage3_model_load_dir)))
    logger.info('Stage3_Model loaded from {}'.format(stage3_model_load_dir))

    stage3_pointmodel = UNet(n_channels=1, n_classes=5)  #
    stage3_model_load_dir = "saved_model\stage3_unet_refine_point_mask\Bestmodel_394.pth"
    stage3_pointmodel.load_state_dict(
        torch.load(os.path.join(root_dir, stage3_model_load_dir)))
    logger.info('Stage3_Model loaded from {}'.format(stage3_model_load_dir))

    # stage3_model = SCSE_UNet(n_channels=1, n_classes=2) #
    # stage3_model_load_dir = "saved_model/stage3_scseunet_refine_label/Bestmodel_82.pth"
    # stage3_model.load_state_dict(torch.load(os.path.join(root_dir, stage3_model_load_dir)))
    # logger.info('Stage3_Model loaded from {}'.format(stage3_model_load_dir))

    stage1_model_whole.cuda()
    stage1_model_segm.cuda()
    stage2_model_box.cuda()
    stage3_boxmodel.cuda()
    stage3_pointmodel.cuda()
    cudnn.benchmark = True  # faster convolutions, but more memory

    # pred_a = pred(s1_modelw=stage1_model_whole,
    #       s1_models=stage1_model_segm,
    #       s2_model=stage2_model_box,
    #       stage3_model=stage3_model,
    #       dataLoader=train_loader,
    #       output_dir=train_output_path,
    #       )
    # pred_a.forward()
    pred_b = pred(
Exemplo n.º 23
0
def main():
    """Train and validate the segmentation model configured by ``Params``.

    Sets up logging and TensorBoard, builds the model (ResUNet34 or UNet),
    a VGG16 feature network, the Adam optimizer, the NLL criterion and the
    train/val data loaders, optionally restores a checkpoint, and then runs
    the epoch loop. After every epoch a checkpoint is saved and the epoch's
    losses and IoUs are written to the results logger and to TensorBoard.

    Raises:
        NotImplementedError: If the configured model name is neither
            'ResUNet34' nor 'UNet'.
    """
    global params, best_iou, num_iter, tb_writer, logger, logger_results
    global vgg_model, criterion_perceptual

    best_iou = 0
    params = Params()
    save_dir = params.paths['save_dir']
    params.save_params('{:s}/params.txt'.format(save_dir))
    tb_writer = SummaryWriter('{:s}/tb_logs'.format(save_dir))

    # restrict the visible GPUs to the configured ones
    gpu_ids = [str(gpu) for gpu in params.train['gpu']]
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(gpu_ids)

    # set up logger
    logger, logger_results = setup_logging(params)

    # ----- create model ----- #
    arch = params.model['name']
    if arch not in ('ResUNet34', 'UNet'):
        raise NotImplementedError()
    if arch == 'ResUNet34':
        model = ResUNet34(params.model['out_c'],
                          fixed_feature=params.model['fix_params'])
    else:
        model = UNet(3, params.model['out_c'])

    logger.info('Model: {:s}'.format(arch))
    model = nn.DataParallel(model).cuda()

    logger.info('=> Using VGG16 for perceptual loss...')
    vgg_model = nn.DataParallel(vgg16_feat()).cuda()
    cudnn.benchmark = True

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 params.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=params.train['weight_decay'])

    # ----- get pixel weights and define criterion ----- #
    if params.train['weight_map']:
        # keep the loss unreduced when weight maps are in use
        logger.info('=> Using weight maps...')
        criterion = torch.nn.NLLLoss(reduction='none').cuda()
    else:
        criterion = torch.nn.NLLLoss().cuda()

    if params.train['beta'] > 0:
        logger.info('=> Using perceptual loss...')
        criterion_perceptual = perceptual_loss()

    splits = ('train', 'val')
    data_transforms = {
        split: get_transforms(params.transform[split])
        for split in splits
    }

    # ----- load data ----- #
    datasets = {}
    for split in splits:
        image_root = '{:s}/{:s}'.format(params.paths['img_dir'], split)
        label_root = '{:s}/{:s}'.format(params.paths['label_dir'], split)
        if not params.train['weight_map']:
            folders = [image_root, label_root]
            suffixes = ['label_with_contours.png']
            channels = [3, 3]
        else:
            weight_root = '{:s}/{:s}'.format(params.paths['weight_map_dir'],
                                             split)
            folders = [image_root, weight_root, label_root]
            suffixes = ['weight.png', 'label_with_contours.png']
            channels = [3, 1, 3]
        datasets[split] = DataFolder(folders, suffixes, channels,
                                     data_transforms[split])

    train_loader = DataLoader(datasets['train'],
                              batch_size=params.train['batch_size'],
                              shuffle=True,
                              num_workers=params.train['workers'])
    val_loader = DataLoader(datasets['val'],
                            batch_size=params.train['val_batch_size'],
                            shuffle=False,
                            num_workers=params.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    ckpt_path = params.train['checkpoint']
    if ckpt_path:
        if not os.path.isfile(ckpt_path):
            logger.info("=> no checkpoint found at '{}'".format(ckpt_path))
        else:
            logger.info("=> loading checkpoint '{}'".format(ckpt_path))
            ckpt = torch.load(ckpt_path)
            params.train['start_epoch'] = ckpt['epoch']
            best_iou = ckpt['best_iou']
            model.load_state_dict(ckpt['state_dict'])
            optimizer.load_state_dict(ckpt['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                ckpt_path, ckpt['epoch']))

    # ----- training and validation ----- #
    num_iter = params.train['num_epochs'] * len(train_loader)

    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(params.train['lr']))
    logger.info("=> Batch size: {:d}".format(params.train['batch_size']))
    logger.info("=> Training epochs: {:d}".format(params.train['num_epochs']))
    logger.info("=> beta: {:.1f}".format(params.train['beta']))

    first_epoch = params.train['start_epoch']
    last_epoch = params.train['num_epochs']
    for epoch in range(first_epoch, last_epoch):
        shown_epoch = epoch + 1

        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(shown_epoch,
                                                params.train['num_epochs']))
        (train_loss, train_loss_ce, train_loss_var, train_iou_nuclei,
         train_iou) = train(train_loader, model, optimizer, criterion, epoch)

        # evaluate on validation set
        with torch.no_grad():
            (val_loss, _val_loss_ce, _val_loss_var, val_iou_nuclei,
             val_iou) = validate(val_loader, model, criterion)

        # check if it is the best accuracy
        combined_iou = (val_iou_nuclei + val_iou) / 2
        is_best = combined_iou > best_iou
        best_iou = max(combined_iou, best_iou)

        cp_flag = shown_epoch % params.train['checkpoint_freq'] == 0

        state = {
            'epoch': shown_epoch,
            'state_dict': model.state_dict(),
            'best_iou': best_iou,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(state, epoch, is_best, params.paths['save_dir'],
                        cp_flag)

        # save the training results to txt files
        result_row = (shown_epoch, train_loss, train_loss_ce, train_loss_var,
                      train_iou_nuclei, train_iou, val_loss, val_iou_nuclei,
                      val_iou)
        logger_results.info(
            '{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'
            .format(*result_row))

        # tensorboard logs
        epoch_losses = {
            'train_loss': train_loss,
            'train_loss_ce': train_loss_ce,
            'train_loss_var': train_loss_var,
            'val_loss': val_loss
        }
        epoch_accuracies = {
            'train_iou_nuclei': train_iou_nuclei,
            'train_iou': train_iou,
            'val_iou_nuclei': val_iou_nuclei,
            'val_iou': val_iou
        }
        tb_writer.add_scalars('epoch_losses', epoch_losses, epoch)
        tb_writer.add_scalars('epoch_accuracies', epoch_accuracies, epoch)

    tb_writer.close()