예제 #1
0
def see_results(n_channels, n_classes, load_weights, dir_img, dir_cmp, savedir,
                title):
    """Run a trained UNet over a comparison dataset and save prediction/GT images.

    Args:
        n_channels: number of input channels for the UNet.
        n_classes: number of output classes for the UNet.
        load_weights: path to a checkpoint containing a 'state_dict' entry.
        dir_img: image directory handed to the dataloader.
        dir_cmp: comparison/ground-truth directory handed to the dataloader.
        savedir: output directory (created if missing).
        title: filename prefix for the saved images.
    """
    # Use GPU or not
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    # Create the model and wrap it for (possible) multi-GPU inference
    net = UNet(n_channels, n_classes).to(device)
    net = torch.nn.DataParallel(
        net, device_ids=list(range(torch.cuda.device_count()))).to(device)

    # Load old weights (mapped to CPU first so loading works on any machine)
    checkpoint = torch.load(load_weights, map_location='cpu')
    net.load_state_dict(checkpoint['state_dict'])

    # Load the dataset
    loader = get_dataloader_show(dir_img, dir_cmp)

    # If savedir does not exists make folder
    if not os.path.exists(savedir):
        os.makedirs(savedir)

    net.eval()
    with torch.no_grad():
        for (data, gt) in loader:
            # Use GPU or not
            data, gt = data.to(device), gt.to(device)

            # Forward
            predictions = net(data)

            # BUGFIX: join paths with os.path.join instead of raw string
            # concatenation, which wrote outside `savedir` (prefixing its
            # parent) whenever savedir lacked a trailing separator.
            save_image(predictions, os.path.join(savedir, title + "_pred.png"))
            save_image(gt, os.path.join(savedir, title + "_gt.png"))
예제 #2
0
def test(path):
    """Count the input image: predict a density map, save it, report counts.

    Returns:
        (header, predicted_counts, real_counts) where header is the image id
        derived from the filename.
    """
    device = torch.device(
        'cuda') if torch.cuda.is_available() else torch.device('cpu')

    # Image: HWC float in [0, 1] -> NCHW tensor with a batch dim
    image = np.array(Image.open(path), dtype=np.float32) / 255
    image = torch.Tensor(np.transpose(image, (2, 0, 1))).unsqueeze(0)

    # Ground Truth: binary mask on the chosen channel scaled by 100 and
    # smoothed, so the label integrates to ~100 per object
    header = ".".join(path.split('/')[-1].split('.')[:2])
    label_path = opt.label_path + header + '.label.png'
    label = np.array(Image.open(label_path))
    if opt.color == 'red':
        labels = 100.0 * (label[:, :, 0] > 0)
    else:
        labels = 100.0 * (label[:, :, 1] > 0)
    labels = ndimage.gaussian_filter(labels, sigma=(1, 1), order=0)
    labels = torch.Tensor(labels).unsqueeze(0)

    # Pick the architecture from the model name
    if opt.model.find("UNet") != -1:
        model = UNet(input_filters=3, filters=opt.unet_filters,
                     N=opt.conv).to(device)
    else:
        model = FCRN_A(input_filters=3, filters=opt.unet_filters,
                       N=opt.conv).to(device)
    model = torch.nn.DataParallel(model)

    if os.path.exists('{}.pth'.format(opt.model)):
        model.load_state_dict(torch.load('{}.pth'.format(opt.model)))

    model.eval()
    image = image.to(device)
    labels = labels.to(device)

    out = model(image)
    # Density maps integrate to 100 per object, hence the division
    predicted_counts = torch.sum(out).item() / 100
    real_counts = torch.sum(labels).item() / 100
    print(predicted_counts, real_counts)

    # BUGFIX: allocate an (H, W, 3) canvas; the original used shape[2] twice
    # ((H, H, 3)) and therefore only worked for square images.
    label = np.zeros((image.shape[2], image.shape[3], 3))
    if opt.color == 'red':
        label[:, :, 0] = out[0][0].cpu().detach().numpy()
    else:
        label[:, :, 1] = out[0][0].cpu().detach().numpy()

    imageio.imwrite('example/test_results/density_map_{}.png'.format(header),
                    label)

    return header, predicted_counts, real_counts
예제 #3
0
def main():
    """Entry point: train a small UNet or load saved weights for evaluation.

    Branches on ``sys.argv[1] == 'train'``.
    NOTE(review): x_train/y_train/x_val/y_val, width_out/height_out and PATH
    are not defined in this chunk (the setup below is commented out) —
    presumably module-level globals; confirm before running.
    """
    # width_in = 284
    # height_in = 284
    # width_out = 196
    # height_out = 196
    # PATH = './unet.pt'
    # x_train, y_train, x_val, y_val = get_dataset(width_in, height_in, width_out, height_out)
    # print(x_train.shape, y_train.shape, x_val.shape, y_val.shape)

    batch_size = 3
    epochs = 1
    epoch_lapse = 50
    threshold = 0.5
    learning_rate = 0.01
    unet = UNet(in_channel=1, out_channel=2)
    if use_gpu:
        unet = unet.cuda()
    criterion = torch.nn.CrossEntropyLoss()
    # BUGFIX: honor the learning_rate variable instead of a duplicated
    # literal, so changing it actually affects the optimizer.
    optimizer = torch.optim.SGD(unet.parameters(), lr=learning_rate,
                                momentum=0.99)
    if sys.argv[1] == 'train':
        train(unet, batch_size, epochs, epoch_lapse, threshold, learning_rate,
              criterion, optimizer, x_train, y_train, x_val, y_val, width_out,
              height_out)
    else:
        if use_gpu:
            unet.load_state_dict(torch.load(PATH))
        else:
            unet.load_state_dict(torch.load(PATH, map_location='cpu'))
        print(unet.eval())
예제 #4
0
def load_model():
    """Build a MobileNetV2-backed UNet, restore pretrained weights on CPU,
    and return it in eval mode."""
    checkpoint = get_model()

    net = UNet(
        backbone="mobilenetv2",
        num_classes=2,
        pretrained_backbone=None,
    )

    # strict=False tolerates missing/unexpected keys in the checkpoint.
    state = torch.load(checkpoint, map_location="cpu")['state_dict']
    net.load_state_dict(state, strict=False)
    net.eval()

    print("model is loaded")

    return net
예제 #5
0
def main():
    """Export a pretrained 3D UNet checkpoint to ONNX (opset 10)."""
    opt = parser.parse_args()
    print(torch.__version__)
    print(opt)

    enc_layers = [1, 2, 2, 4]
    dec_layers = [1, 1, 1, 1]
    # Channel widths double per level: [16, 32, 64, 128]
    number_of_channels = [int(8 * 2 ** i)
                          for i in range(1, len(enc_layers) + 1)]

    model = UNet(depth=len(enc_layers),
                 encoder_layers=enc_layers,
                 decoder_layers=dec_layers,
                 number_of_channels=number_of_channels,
                 number_of_outputs=3)

    checkpoint_path = os.path.join(opt.models_path, opt.name,
                                   opt.name + 'best_model.pth')
    saved = torch.load(checkpoint_path, map_location='cpu')

    # Checkpoint was saved from a DataParallel model: strip the leading
    # 'module.' prefix from every key before loading.
    new_state_dict = OrderedDict()
    for key, value in saved['model'].state_dict().items():
        new_state_dict[key[7:]] = value
    model.load_state_dict(new_state_dict)
    model.eval()

    # Dummy 4-channel volume used to trace the graph
    x = torch.randn(1, 4,
                    opt.input_size[0], opt.input_size[1], opt.input_size[2],
                    requires_grad=True)

    register_op("group_norm", group_norm_symbolic, "", 10)

    torch_out = torch.onnx.export(
        model,
        [x, ],
        os.path.join(opt.models_path, opt.name, opt.name + ".onnx"),
        export_params=True,  # store the trained parameter weights in the file
        verbose=True,
        opset_version=10)
예제 #6
0
def test(args):
    """
    Test some data from trained UNet: predict per-class masks for one image
    and save a colorized segmentation PNG to args.test_save_path.
    """
    image = load_test_image(args.test_image)  # 1 c w h
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()
        image = image.cuda()
    print('Loading model param from {}'.format(args.model_state_dict))
    net.load_state_dict(torch.load(args.model_state_dict))
    net.eval()

    print('Predicting for {}...'.format(args.test_image))
    ys_pred = net(image)  # 1 ch w h

    # Per-channel display colors, one RGB triple per class
    colors = []
    with open(args.mask_json_path, 'r', encoding='utf-8') as mask:
        print('Reading mask colors list from {}'.format(args.mask_json_path))
        colors = json.loads(mask.read())
        colors = [tuple(c) for c in colors]
        print('Mask colors: {}'.format(colors))

    # Binarize per-channel predictions at 0.5
    ys_pred = ys_pred.cpu().detach().numpy()[0]
    ys_pred[ys_pred < 0.5] = 0
    ys_pred[ys_pred >= 0.5] = 1
    # BUGFIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
    # use the builtin int dtype instead.
    ys_pred = ys_pred.astype(int)
    image_w = ys_pred.shape[1]
    image_h = ys_pred.shape[2]
    out_image = np.zeros((image_w, image_h, 3))

    # Paint each positive channel with its color; later channels overwrite
    # earlier ones, matching the original behavior.
    for w in range(image_w):
        for h in range(image_h):
            for ch in range(ys_pred.shape[0]):
                if ys_pred[ch][w][h] == 1:
                    out_image[w][h][0] = colors[ch][0]
                    out_image[w][h][1] = colors[ch][1]
                    out_image[w][h][2] = colors[ch][2]

    out_image = out_image.astype(np.uint8)  # w h c
    out_image = out_image.transpose((1, 0, 2))  # h w c
    out_image = Image.fromarray(out_image)
    out_image.save(args.test_save_path)
    print('Segmentation result has been saved to {}'.format(
        args.test_save_path))
class modelLoader():
    """Small helper that builds a UNet and restores a saved checkpoint."""

    def __init__(self, model_folder, in_chnl, out_chnl, checkpoint=None):
        # Only configuration here; the network is built lazily in set_model().
        self.checkpoint = checkpoint
        self.model_folder = model_folder
        self.in_chnl = in_chnl
        self.out_chnl = out_chnl
        self.model = None

    def set_model(self):
        """Instantiate the UNet, load checkpoint weights and switch to eval."""
        self.model = UNet(self.in_chnl, self.out_chnl).to(device)
        weights_path = "./{}/checkpoint_{}.pth".format(self.model_folder,
                                                       self.checkpoint)
        self.model.load_state_dict(torch.load(weights_path))
        self.model.eval()

    def model_infos(self):
        """Print a one-line summary of this loader's configuration."""
        info = ("Checkpoint: {}, Model folder: ./{}, Input channel: {}, "
                "Output channel: {}").format(self.checkpoint,
                                             self.model_folder,
                                             self.in_chnl, self.out_chnl)
        print(info)
예제 #8
0
def test(args):
    """Run a trained UNet over the liver test split, printing input shapes."""
    model = UNet(n_channels=125, n_classes=10).to(device)
    state = torch.load(args.ckpt % args.num_epochs, map_location='cpu')
    model.load_state_dict(state)

    dataset = LiverDataset('data',
                           transform=x_transform,
                           target_transform=y_transform,
                           train=False)
    loader = DataLoader(dataset, batch_size=1)

    model.eval()
    with torch.no_grad():
        for x, y in loader:
            # Reorder axes before feeding the network
            x = x.permute(0, 2, 1, 3)
            print(x.shape)
            x = x.float().to(device)
            label = y.float().to(device)
            outputs = model(x)
            # print(outputs.shape)
            label = label.squeeze(1)
예제 #9
0
        # Generator update(s) for this batch (a single step here).
        for g_iter in range(1):
            # generator
            optimizer_G.zero_grad()
            # Pair the input image with the generator output for the critic.
            gen_imgs = torch.cat((img, output), 1)

            # Adversarial term: push the critic score on generated pairs up.
            loss_G = -torch.mean(D(gen_imgs))
            loss_focal = criterion(output, label)
            loss = loss_focal + loss_G

            loss.backward()
            optimizer_G.step()

            # Accumulate the per-sample focal loss over the epoch.
            train_loss += loss_focal.item() / trainSize

    # Validation pass: eval mode, no gradients.
    G.eval(), D.eval()
    with torch.no_grad():
        for batch in val_dataloader:
            img_v, label_v = batch[0].to(device), batch[1].to(device)

            output_v = G(img_v)
            loss = criterion(output_v, label_v)

            val_loss += loss.item() / valSize

    # Track and persist the loss history every epoch.
    loss_track.append((train_loss, loss_G, loss_D, val_loss))
    torch.save(loss_track, 'checkpoint_GAN/loss.pth')

    print(
        '[{:4d}/{}], tr_ls: {:.5f}, G_ls: {:.5f}, D_ls: {:.5f}, te_ls: {:.5f}'.
        format(epoch + 1, epoch_num, train_loss, loss_G, loss_D, val_loss))
예제 #10
0
class Trainer():
	"""Train, validate and test a UNet segmentation model with early stopping.

	Expects `config` to provide: save_model_dir, bestModel, use_gpu, resume,
	epochs, train_paitence and tensorboard_path.
	"""

	def __init__(self, config, trainLoader, validLoader):
		self.config = config
		self.trainLoader = trainLoader
		self.validLoader = validLoader

		self.numTrain = len(self.trainLoader.dataset)
		self.numValid = len(self.validLoader.dataset)

		self.saveModelDir = str(self.config.save_model_dir) + "/"

		self.bestModel = config.bestModel
		self.useGpu = self.config.use_gpu

		self.net = UNet()

		if self.config.resume == True:
			print("LOADING SAVED MODEL")
			self.loadCheckpoint()
		else:
			print("INTIALIZING NEW MODEL")

		self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
		self.net = self.net.to(self.device)

		self.totalEpochs = config.epochs

		self.optimizer = optim.Adam(self.net.parameters(), lr=5e-4)
		self.loss = DiceLoss()

		self.num_params = sum([p.data.nelement() for p in self.net.parameters()])

		self.trainPaitence = config.train_paitence

		# BUGFIX: the early-stopping counter must exist before train() reads
		# it; previously it was only created after the first improvement, so a
		# run whose first epoch did not improve crashed with AttributeError.
		self.counter = 0

		# BUGFIX: the writer is used by trainOneEpoch()/validationTest() on
		# every run, so create it unconditionally (it used to be skipped when
		# resuming, which crashed with AttributeError).
		self.writer = SummaryWriter(self.config.tensorboard_path + "/")

		if not self.config.resume:
			summary(self.net, input_size=(3, 256, 256))
			print('[*] Number of model parameters: {:,}'.format(self.num_params))

	def train(self):
		"""Run the full training loop with early stopping on validation IOU."""
		bestIOU = 0

		print("\n[*] Train on {} sample pairs, validate on {} trials".format(
			self.numTrain, self.numValid))

		for epoch in range(0, self.totalEpochs):
			print('\nEpoch: {}/{}'.format(epoch + 1, self.totalEpochs))

			self.trainOneEpoch(epoch)

			validationIOU = self.validationTest(epoch)

			print("VALIDATION IOU: ", validationIOU)

			# check for improvement; reset the patience counter when found
			if validationIOU > bestIOU:
				print("COUNT RESET !!!")
				bestIOU = validationIOU
				self.counter = 0
				self.saveCheckPoint(
					{
						'epoch': epoch + 1,
						'model_state': self.net.state_dict(),
						'optim_state': self.optimizer.state_dict(),
						'best_valid_acc': bestIOU,
					}, True)
			else:
				self.counter += 1

			if self.counter > self.trainPaitence:
				self.saveCheckPoint(
					{
						'epoch': epoch + 1,
						'model_state': self.net.state_dict(),
						'optim_state': self.optimizer.state_dict(),
						'best_valid_acc': validationIOU,
					}, False)
				print("[!] No improvement in a while, stopping training...")
				print("BEST VALIDATION IOU: ", bestIOU)

				return None

	def trainOneEpoch(self, epoch):
		"""One optimization pass over the training loader; logs loss and IOU."""
		self.net.train()
		train_loss = 0
		total_IOU = 0

		for batch_idx, (images, targets) in enumerate(self.trainLoader):
			images = images.to(self.device)
			targets = targets.to(self.device)

			self.optimizer.zero_grad()

			outputMaps = self.net(images)
			loss = self.loss(outputMaps, targets)

			loss.backward()
			self.optimizer.step()

			train_loss += loss.item()

			current_IOU = calc_IOU(outputMaps, targets)
			total_IOU += current_IOU

			del(images)
			del(targets)

			progress_bar(batch_idx, len(self.trainLoader), 'Loss: %.3f | IOU: %.3f'
				% (train_loss / (batch_idx + 1), current_IOU))

		# BUGFIX: parenthesize the divisor — the original logged
		# (train_loss / batch_idx) + 1 due to operator precedence.
		self.writer.add_scalar('Train/Loss', train_loss / (batch_idx + 1), epoch)
		self.writer.add_scalar('Train/IOU', total_IOU / (batch_idx + 1), epoch)

	def validationTest(self, epoch):
		"""Evaluate on the validation loader; returns the mean IOU."""
		self.net.eval()
		validationLoss = []
		total_IOU = []
		with torch.no_grad():
			for batch_idx, (images, targets) in enumerate(self.validLoader):
				images = images.to(self.device)
				targets = targets.to(self.device)

				outputMaps = self.net(images)

				loss = self.loss(outputMaps, targets)

				validationLoss.append(loss.item())
				total_IOU.append(calc_IOU(outputMaps, targets))

				del(images)
				del(targets)

		meanIOU = np.mean(total_IOU)
		meanValidationLoss = np.mean(validationLoss)
		self.writer.add_scalar('Validation/Loss', meanValidationLoss, epoch)
		self.writer.add_scalar('Validation/IOU', meanIOU, epoch)

		print("VALIDATION LOSS: ", meanValidationLoss)

		return meanIOU

	def test(self, dataLoader):
		"""Evaluate a single batch from `dataLoader`.

		Returns (input images, output maps) as lists of numpy arrays.
		NOTE: intentionally breaks after the first batch, matching the
		original behavior.
		"""
		self.net.eval()
		testLoss = []
		total_IOU = []

		total_outputs_maps = []
		total_input_images = []

		with torch.no_grad():
			for batch_idx, (images, targets) in enumerate(dataLoader):
				images = images.to(self.device)
				targets = targets.to(self.device)

				outputMaps = self.net(images)

				loss = self.loss(outputMaps, targets)

				testLoss.append(loss.item())
				total_IOU.append(calc_IOU(outputMaps, targets))

				total_outputs_maps.append(outputMaps.cpu().detach().numpy())
				total_input_images.append(images.cpu().detach().numpy())

				del(images)
				del(targets)
				break

		meanIOU = np.mean(total_IOU)
		meanLoss = np.mean(testLoss)
		print("TEST IOU: ", meanIOU)
		print("TEST LOSS: ", meanLoss)

		return total_input_images, total_outputs_maps

	def saveCheckPoint(self, state, isBest):
		"""Persist `state` to model.pth; also copy to best_model.pth when isBest."""
		filename = "model.pth"
		ckpt_path = os.path.join(self.saveModelDir, filename)
		torch.save(state, ckpt_path)

		if isBest:
			filename = "best_model.pth"
			shutil.copyfile(ckpt_path, os.path.join(self.saveModelDir, filename))

	def loadCheckpoint(self):
		"""Restore network weights from the saved (best) checkpoint."""
		print("[*] Loading model from {}".format(self.saveModelDir))
		if self.bestModel:
			print("LOADING BEST MODEL")
			filename = "best_model.pth"
		else:
			filename = "model.pth"

		ckpt_path = os.path.join(self.saveModelDir, filename)
		print(ckpt_path)

		if self.useGpu == False:
			# BUGFIX: the checkpoint is a dict (see saveCheckPoint); the
			# original assigned it directly to self.net, leaving a dict where
			# a module was expected. Load the state dict instead.
			self.ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
			self.net.load_state_dict(self.ckpt['model_state'])
		else:
			print("*" * 40 + " LOADING MODEL FROM GPU " + "*" * 40)
			self.ckpt = torch.load(ckpt_path)
			self.net.load_state_dict(self.ckpt['model_state'])

			self.net.cuda()
예제 #11
0
def run_inference(args):
    """Run per-district, per-year forest-cover inference and save RGB maps.

    Args (attributes of `args`): model_topology, bands, classes, model_path,
    cuda, device, data_path, rasterized_shapefiles_path, bs, dest.
    """
    model = UNet(topology=args.model_topology,
                 input_channels=len(args.bands),
                 num_classes=len(args.classes))
    # strict=False tolerates key mismatches between checkpoint and model
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'),
                          strict=False)
    print('Log: Loaded pretrained {}'.format(args.model_path))
    model.eval()
    if args.cuda:
        print('log: Using GPU')
        model.cuda(device=args.device)
    # all_districts = ["abbottabad", "battagram", "buner", "chitral", "hangu", "haripur", "karak", "kohat", "kohistan", "lower_dir", "malakand", "mansehra",
    # "nowshehra", "shangla", "swat", "tor_ghar", "upper_dir"]
    all_districts = ["abbottabad"]

    # years = [2014, 2016, 2017, 2018, 2019, 2020]
    years = [2016]
    # change this to do this for all the images in that directory
    for district in all_districts:
        for year in years:
            print("(LOG): On District: {} @ Year: {}".format(district, year))
            # test_image_path = os.path.join(args.data_path, 'landsat8_4326_30_{}_region_{}.tif'.format(year, district))
            test_image_path = os.path.join(args.data_path,
                                           'landsat8_{}_region_{}.tif'.format(
                                               year, district))  #added(nauman)
            inference_loader, adjustment_mask = get_inference_loader(
                rasterized_shapefiles_path=args.rasterized_shapefiles_path,
                district=district,
                image_path=test_image_path,
                model_input_size=128,
                bands=args.bands,
                num_classes=len(args.classes),
                batch_size=args.bs,
                num_workers=4)
            # we need to fill our new generated test image
            generated_map = np.empty(
                shape=inference_loader.dataset.get_image_size())
            for idx, data in enumerate(inference_loader):
                coordinates, test_x = data['coordinates'].tolist(
                ), data['input']
                test_x = test_x.cuda(
                    device=args.device) if args.cuda else test_x
                # FIX: call the module (model(x)) instead of model.forward(x)
                # so registered hooks are honored.
                out_x, softmaxed = model(test_x)
                pred = torch.argmax(softmaxed, dim=1)
                pred_numpy = pred.cpu().numpy().transpose(1, 2, 0)
                if idx % 5 == 0:
                    print('LOG: on {} of {}'.format(idx,
                                                    len(inference_loader)))
                # Stitch each patch back at its source coordinates
                for k in range(test_x.shape[0]):
                    x, x_, y, y_ = coordinates[k]
                    generated_map[x:x_, y:y_] = pred_numpy[:, :, k]
            # adjust the inferred map
            generated_map += 1  # to make forest pixels: 2, non-forest pixels: 1, null pixels: 0
            generated_map = np.multiply(generated_map, adjustment_mask)
            # save generated map as png image, not numpy array:
            # green = forest, red = non-forest, black = null
            forest_map_rband = np.zeros_like(generated_map)
            forest_map_gband = np.zeros_like(generated_map)
            forest_map_bband = np.zeros_like(generated_map)
            forest_map_gband[generated_map == FOREST_LABEL] = 255
            forest_map_rband[generated_map == NON_FOREST_LABEL] = 255
            forest_map_for_visualization = np.dstack(
                [forest_map_rband, forest_map_gband,
                 forest_map_bband]).astype(np.uint8)
            save_this_map_path = os.path.join(
                args.dest, '{}_{}_inferred_map.png'.format(district, year))
            matimg.imsave(save_this_map_path, forest_map_for_visualization)
            print('Saved: {} @ {}'.format(save_this_map_path,
                                          forest_map_for_visualization.shape))
예제 #12
0
            # Clamp outputs into the displayable [0, 1] range before saving.
            output = np.clip(output, a_min = 0, a_max = 1)

            # Global step index used for file naming.
            # NOTE(review): `num_bach_train` looks like a typo of
            # `num_batch_train` (used in the commented lines below) — confirm
            # which name the surrounding code defines. `id` also shadows the
            # builtin of the same name.
            id = num_bach_train * (epoch - 1) + batch

            plt.imsave(os.path.join(result_dir_train, 'png', '%04d_label.png' % id), label[0])
            plt.imsave(os.path.join(result_dir_train, 'png', '%04d_input.png' % id), input[0])
            plt.imsave(os.path.join(result_dir_train, 'png', '%04d_output.png' % id), output[0])

            # writer_train.add_image('label', label, num_batch_train * (epoch - 1) + batch, dataformats = 'NHWC')
            # writer_train.add_image('input', input, num_batch_train * (epoch - 1) + batch, dataformats = 'NHWC')
            # writer_train.add_image('output', output, num_batch_train * (epoch - 1) + batch, dataformats = 'NHWC')

        # Mean training loss for this epoch.
        writer_train.add_scalar('loss', np.mean(loss_arr), epoch)

        # Validation pass: eval mode, no gradients.
        with torch.no_grad():
            net.eval()
            loss_arr = []

            for batch, data in enumerate(loader_val, 1):
                # forward pass
                label = data['label'].to(device)
                input = data['input'].to(device)

                output = net(input)

                # compute the loss
                loss = fn_loss(output, label)

                loss_arr += [loss.item()]

                # (this print statement continues past the visible chunk)
                print("VALID: EPOCH %04d / %04d | BATCH %04d / %04d | LOSS %.4f" %
예제 #13
0
    def __call__(self, input, target_is_real):
        """Compute the GAN loss of `input` against a real/fake target tensor."""
        # BUGFIX: the call below was missing its closing parenthesis
        # (SyntaxError in the original).
        target_tensor = self.get_target_tensor(input, target_is_real)
        return self.loss(input, target_tensor.cuda())

class HDF5Dataset(Dataset):
    """Dataset serving (input, gt) pairs out of a single HDF5 file.

    The train/test split is given by the index files 'train.txt' / 'test.txt'
    (one id per line). Samples are normalized by the module-level
    `max_im` / `max_gt` constants.
    """

    def __init__(self, img_dir, isTrain=True):
        self.isTrain = isTrain
        fold_dir = "train.txt" if isTrain else "test.txt"

        # BUGFIX: close the index file (the original leaked the handle).
        # line[0:-1] strips the trailing newline, as before.
        with open(fold_dir, 'r') as ids:
            self.index_list = [line[0:-1] for line in ids]

        self.img_dir = img_dir

    def __len__(self):
        return len(self.index_list)

    def __getitem__(self, index):
        id_ = int(self.index_list[index])
        # Single open instead of the original's two sequential opens.
        with h5py.File(self.img_dir, 'r') as db:
            _img = db['input'][id_]
            _target = db['gt'][id_]
            # Skip empty ground-truth masks by falling back to the next id.
            if np.max(_target) == 0:
                _img = db['input'][id_ + 1]
                _target = db['gt'][id_ + 1]
        _img = torch.from_numpy(np.divide(_img, max_im)).float()
        _target = torch.from_numpy(np.divide(_target, max_gt)).float()

        return _img, _target

class XSigmoidLoss(nn.Module):
    """X-sigmoid regression loss: mean(2*e/(1 + exp(-e)) - e), e = y_t - y'."""

    def __init__(self):
        super().__init__()

    def forward(self, y_t, y_prime_t):
        diff = y_t - y_prime_t
        sigmoid_scaled = 2 * diff / (1 + torch.exp(-diff))
        return torch.mean(sigmoid_scaled - diff)


# Path to the synthetic training HDF5 file (train and test splits share it).
img_dir = '/n/holyscratch01/wadduwage_lab/uom_bme/ForwardModel_matlab/_cnn_synthTrData/03-Jun-2020/cells_tr_data_6sls_03-Jun-2020.h5'
dataset_ = HDF5Dataset(img_dir=img_dir, isTrain=True)
training_data_loader = DataLoader(dataset=dataset_, batch_size=args['batch_size'], shuffle=True, num_workers=0, drop_last=True)

dataset_test = HDF5Dataset(img_dir=img_dir, isTrain=False)
testing_data_loader = DataLoader(dataset=dataset_test, batch_size=args['batch_size'], shuffle=True, num_workers=0, drop_last=True)


# Generator / discriminator, each wrapped for multi-GPU use.
netG = UNet(n_classes=args['output_nc']).cuda()
netG = torch.nn.parallel.DataParallel(netG, device_ids=range(args['num_gpus']))
netD = Discriminator().cuda()
netD = torch.nn.parallel.DataParallel(netD, device_ids=range(args['num_gpus']))

# Loss functions (all moved to GPU).
criterionGAN = GANLoss().cuda()
criterionL1 = nn.L1Loss().cuda()
criterionMSE = nn.MSELoss().cuda()
criterionxsig = XSigmoidLoss().cuda()
# setup optimizer
# NOTE(review): only optimizerD is created here, but train() below also uses
# optimizerG — presumably created elsewhere; confirm.
optimizerD = optim.Adam(netD.parameters(), lr=args['lr'], betas=(args['beta1'], 0.999))

# Persistent input/target GPU buffers, resized to each incoming batch.
real_a = torch.FloatTensor(args['batch_size'], args['input_nc'], 128, 128).cuda()
real_b = torch.FloatTensor(args['batch_size'], args['output_nc'], 128, 128).cuda()


# NOTE(review): torch.autograd.Variable is a deprecated no-op wrapper on
# modern PyTorch; plain tensors work directly.
real_a = Variable(real_a)
real_b = Variable(real_b)

resume_epoch = 35
def test(args, model, device, test_loader, k_fold, class_weights):
    """Evaluate `model` on `test_loader` with class-weighted cross-entropy.

    Returns:
        (mean test loss, accuracy in percent, report dict with
        'correct'/'total' counts).
    """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.cross_entropy(output, target, weight = class_weights).item()  # sum up batch loss
            # BUGFIX: `correct` was never updated, so the reported accuracy
            # was always 0; count argmax matches against the targets.
            pred = output.argmax(dim=1)
            correct += pred.eq(target).sum().item()

    test_loss /= len(test_loader.dataset)
    # BUGFIX: `report` was undefined, so the original raised NameError on
    # return; provide a small summary dict instead.
    report = {'correct': correct, 'total': len(test_loader.dataset)}
    return test_loss, 100. * correct / len(test_loader.dataset) , report

def train(epoch):
    """One epoch of pix2pix-style adversarial training + a test-split L1 report.

    Relies on module-level netG/netD, criterionGAN/criterionL1, the
    optimizers, the loaders and the pre-allocated real_a/real_b buffers.
    """
    for iteration, batch in enumerate(training_data_loader, 1):
        # Copy this batch into the persistent GPU buffers.
        # BUGFIX: the copy lines previously mixed spaces and tabs, which is
        # a TabError under Python 3.
        real_a_cpu, real_b_cpu = batch[0], batch[1]
        real_a.resize_(real_a_cpu.size()).copy_(real_a_cpu)
        real_b.resize_(real_b_cpu.size()).copy_(real_b_cpu)

        fake_b = netG(real_a)

        ############################
        # (1) Update D network: maximize log(D(x,y)) + log(1 - D(x,G(x)))
        ###########################
        optimizerD.zero_grad()

        # train with fake pairs (detached so G receives no gradient here)
        fake_ab = torch.cat((real_a, fake_b), 1)
        pred_fake = netD.forward(fake_ab.detach())
        loss_d_fake = criterionGAN(pred_fake, False)

        # train with real pairs
        real_ab = torch.cat((real_a, real_b), 1)
        pred_real = netD.forward(real_ab)
        loss_d_real = criterionGAN(pred_real, True)

        # Combined loss
        loss_d = (loss_d_fake + loss_d_real) * 0.5

        loss_d.backward()
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(x,G(x))) + L1(y,G(x))
        ##########################
        # NOTE(review): optimizerG is not defined in this chunk — presumably
        # created alongside optimizerD elsewhere; confirm.
        optimizerG.zero_grad()
        # First, G(A) should fake the discriminator
        fake_ab = torch.cat((real_a, fake_b), 1)
        pred_fake = netD.forward(fake_ab)
        loss_g_gan = criterionGAN(pred_fake, True)

        # Second, G(A) = B (L1 reconstruction, weight 10)
        loss_g_l1 = criterionL1(fake_b, real_b) * 10

        loss_g = loss_g_gan + loss_g_l1

        loss_g.backward()
        optimizerG.step()

        if iteration % 200 == 0:
            print("===> Epoch[{}]({}/{}): Loss_D: {:.4f} Loss_G: {:.4f}".format(
                epoch, iteration, len(training_data_loader), loss_d.item(), loss_g.item()))

    # Quick evaluation pass: mean L1 over the test split.
    netG.eval()
    test_loss = 0
    for iteration, batch in enumerate(testing_data_loader, 1):
        real_a_cpu, real_b_cpu = batch[0], batch[1]
        real_a.resize_(real_a_cpu.size()).copy_(real_a_cpu)
        real_b.resize_(real_b_cpu.size()).copy_(real_b_cpu)
        fake_b = netG(real_a)
        test_loss += criterionL1(fake_b, real_b).item()
    print(len(testing_data_loader.dataset))
    test_loss /= len(testing_data_loader.dataset)
    print('epoch[{}]: Loss_test: {:.4f}'.format(epoch, test_loss))
           
def checkpoint(epoch):
    """Serialize the full netG/netD modules for this epoch under checkpoint/<dataset>/."""
    out_dir = os.path.join("checkpoint", args['dataset'])
    # Create checkpoint/<dataset>/ (and any parents) if missing — replaces
    # the two racy exists()+mkdir() checks.
    os.makedirs(out_dir, exist_ok=True)
    net_g_model_out_path = "checkpoint/{}/netG_model_epoch_{}.pth.tar".format(args['dataset'], epoch)
    net_d_model_out_path = "checkpoint/{}/netD_model_epoch_{}.pth.tar".format(args['dataset'], epoch)
    # NOTE: saves whole module objects (not state_dicts), as the original did.
    torch.save(netG, net_g_model_out_path)
    torch.save(netD, net_d_model_out_path)
    # BUGFIX: the saved-to message was missing the path separator
    # ("checkpointfacades" instead of "checkpoint/facades").
    print("Checkpoint saved to {}".format(out_dir))

# Main loop: one adversarial training epoch followed by a checkpoint dump.
for epoch in range(1, args['num_epoch'] + 1):
    train(epoch)
    checkpoint(epoch)
예제 #14
0
class Trainer(object):
    """Trainer for training and testing the model"""
    def __init__(self, data_loader, config):
        """Initialize configurations"""

        # model configuration
        self.in_dim = config.in_dim
        self.out_dim = config.out_dim
        self.num_filters = config.num_filters
        self.patch_size = config.patch_size

        # training configuration
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.weight_decay = config.weight_decay
        self.resume_iters = config.resume_iters
        self.mode = config.mode

        # miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda:{}'.format(config.device_id) \
                                   if self.use_cuda else 'cpu')

        # training result configuration
        self.log_dir = config.log_dir
        self.log_step = config.log_step
        self.model_save_dir = config.model_save_dir
        self.model_save_step = config.model_save_step

        # Build the model and tensorboard.
        self.build_model()
        if self.use_tensorboard:
            self.build_tensorboard()

        # data loader
        # In 'train'/'test' mode a single loader is passed; in 'train_test'
        # mode a (train_loader, test_loader) pair is expected instead.
        if self.mode == 'train' or self.mode == 'test':
            self.data_loader = data_loader
        else:
            self.train_data_loader, self.test_data_loader = data_loader

    def build_model(self):
        """Create a model"""
        # UNet with configured channel counts; cast to float32 before the
        # optimizer captures the parameters.
        self.model = UNet(self.in_dim, self.out_dim, self.num_filters)
        self.model = self.model.float()
        self.optimizer = torch.optim.Adam(self.model.parameters(),
                                          self.lr, [self.beta1, self.beta2],
                                          weight_decay=self.weight_decay)
        self.print_network(self.model, 'unet')
        self.model.to(self.device)

    def _load(self, checkpoint_path):
        # Load a checkpoint; on CPU-only machines map CUDA tensors to CPU
        # storage so GPU-trained checkpoints still load.
        if self.use_cuda:
            checkpoint = torch.load(checkpoint_path)
        else:
            checkpoint = torch.load(checkpoint_path,
                                    map_location=lambda storage, loc: storage)
        return checkpoint

    def restore_model(self, resume_iters):
        """Restore the trained model"""

        print(
            'Loading the trained models from step {}...'.format(resume_iters))
        # Checkpoints are named "<iteration>-unet.ckpt" (see train()).
        model_path = os.path.join(self.model_save_dir,
                                  '{}-unet'.format(resume_iters) + '.ckpt')
        checkpoint = self._load(model_path)
        self.model.load_state_dict(checkpoint['model'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

    def print_network(self, model, name):
        """Print out the network information"""

        # Total trainable-plus-frozen parameter count.
        num_params = 0
        for p in model.parameters():
            num_params += p.numel()
        #print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))

    def print_optimizer(self, opt, name):
        """Print out optimizer information"""

        print(opt)
        print(name)

    def build_tensorboard(self):
        """Build tensorboard for visualization"""

        from logger import Logger
        self.logger = Logger(self.log_dir)

    def reset_grad(self):
        """Reset the gradient buffers."""

        self.optimizer.zero_grad()

    def train(self):
        """Train model"""
        if self.mode != 'train_test':
            data_loader = self.data_loader
        else:
            data_loader = self.train_data_loader

        print("current dataset size: ", len(data_loader))
        data_iter = iter(data_loader)

        if not os.path.exists(self.model_save_dir):
            os.makedirs(self.model_save_dir)

        # start training from scratch or resume training.
        start_iters = 0
        if self.resume_iters:
            print('Resuming ...')
            start_iters = self.resume_iters
            self.restore_model(self.resume_iters)
            self.print_optimizer(self.optimizer, 'optimizer')

        # print learning rate information
        lr = self.lr
        print('Current learning rates, g_lr: {}.'.format(lr))

        # start training.
        print('Start training...')
        start_time = time.time()
        for i in range(start_iters, self.num_iters):

            # fetch batch data
            # NOTE(review): the except path re-creates the iterator when the
            # loader is exhausted, but it unpacks FIVE values while the try
            # path unpacks two — the loaders apparently yield different
            # tuple shapes (cf. test(), which unpacks five).  Also, the bare
            # except hides any other error as loader exhaustion.  Confirm
            # the train loader's yield shape before touching this.
            try:
                in_data, label = next(data_iter)
            except:
                data_iter = iter(data_loader)
                in_data, label, _, _, _ = next(data_iter)

            in_data = in_data.float().to(self.device)
            label = label.to(self.device)

            # train the model
            self.model = self.model.train()
            y_out = self.model(in_data)
            # A fresh BCEWithLogitsLoss module is built each step; it is
            # stateless, so this only costs a tiny allocation.
            loss = nn.BCEWithLogitsLoss()
            output = loss(y_out, label)
            self.reset_grad()
            output.backward()
            self.optimizer.step()

            # logging
            if (i + 1) % self.log_step == 0:
                et = time.time() - start_time
                et = str(datetime.timedelta(seconds=et))[:-7]
                log = "Elapsed [{}], Iteration [{}/{}]".format(
                    et, i + 1, self.num_iters)
                log += ", {}: {:.4f}".format("loss", output.mean().item())
                print(log)

                if self.use_tensorboard:
                    self.logger.scalar_summary("loss",
                                               output.mean().item(), i + 1)

            # save model checkpoints
            if (i + 1) % self.model_save_step == 0:
                path = os.path.join(self.model_save_dir,
                                    '{}-unet'.format(i + 1) + '.ckpt')
                torch.save(
                    {
                        'model': self.model.state_dict(),
                        'optimizer': self.optimizer.state_dict()
                    }, path)
                print('Saved model checkpoints into {}...'.format(
                    self.model_save_dir))

    def test(self):
        """Test model"""

        if self.mode != 'train_test':
            data_loader = self.data_loader
        else:
            data_loader = self.test_data_loader
        print("current dataset size: ", len(data_loader))
        data_iter = iter(data_loader)

        # start testing on trained model
        if self.resume_iters and self.mode != 'train_test':
            print('Resuming ...')
            self.restore_model(self.resume_iters)

        # start testing.
        # Accumulators for stitching patch predictions back into the full
        # volume.  Assumes the volume is 78x110x24 — TODO confirm against
        # the dataset.
        result, trace = np.zeros((78, 110, 24)), np.zeros((78, 110, 24))
        print('Start testing...')
        correct, total, bcorrect = 0, 0, 0
        while (True):

            # fetch batch data
            # NOTE(review): bare except is used as the loop-exit condition;
            # any unexpected error in the loader also ends the loop silently.
            try:
                data_in, label, i, j, k = next(data_iter)
            except:
                break

            data_in = data_in.float().to(self.device)
            label = label.float().to(self.device)

            # test the model
            self.model = self.model.eval()
            y_hat = self.model(data_in)
            m = nn.Sigmoid()
            y_hat = m(y_hat)
            y_hat = y_hat.squeeze().detach().cpu().numpy()

            label = label.cpu().numpy().astype(int)
            # Binarize: prediction threshold 0.2 on the sigmoid output
            # (deliberately below 0.5 — presumably to boost recall; confirm).
            y_hat_th = (y_hat > 0.2)
            label = (label > 0.5)
            test = (label == y_hat_th)
            correct += np.sum(test)
            # Baseline: accuracy of always predicting background.
            btest = (label == 0)
            bcorrect += np.sum(btest)
            total += y_hat_th.size

            # (i, j, k) are the patch-center coordinates; paste each patch
            # prediction into the volume and count overlaps in `trace`.
            radius = int(self.patch_size / 2)
            for step in range(self.batch_size):
                x, y, z, pred = i[step], j[step], k[step], np.squeeze(
                    y_hat_th[step, :, :, :])
                result[x - radius:x + radius, y - radius:y + radius,
                       z - radius:z + radius] += pred
                trace[x - radius:x + radius, y - radius:y + radius,
                      z - radius:z + radius] += np.ones(
                          (self.patch_size, self.patch_size, self.patch_size))

        print('Accuracy: %.3f%%' % (correct / total * 100))
        print('Baseline Accuracy: %.3f%%' % (bcorrect / total * 100))

        # Avoid division by zero where no patch covered a voxel, then
        # average overlapping predictions.
        trace += (trace == 0)
        result = result / trace
        scipy.io.savemat('prediction.mat', {'result': result})

    def train_test(self):
        """Train and test model"""

        self.train()
        self.test()
예제 #15
0
    out = model(torch.from_numpy(img)).cpu().detach().numpy()
    return np.squeeze(out, 0)[0]


# Flask app configuration: uploaded images and rendered predictions are both
# served from the static folder; inference runs on CPU.
app = Flask(__name__)
UPLOAD_FOLDER = "/home/atom/projects/data_science_bowl_2018/src/static"
PRED_PATH = "/home/atom/projects/data_science_bowl_2018/src/static/"
DEVICE = "cpu"


@app.route("/", methods=['GET', 'POST'])
def upload_predict():
    """Serve the upload form; on POST, segment the uploaded image.

    The uploaded file is stored in UPLOAD_FOLDER, run through the global
    MODEL, and the prediction written as ``pred.png`` next to it before
    re-rendering the page with both images.
    """
    if request.method == "POST":
        uploaded = request.files['image']
        if uploaded:
            stored_at = os.path.join(UPLOAD_FOLDER, uploaded.filename)
            uploaded.save(stored_at)
            mask = predict(stored_at, MODEL)
            imsave(PRED_PATH + "pred.png", mask)
            return render_template("index.html",
                                   image_loc=uploaded.filename,
                                   pred_loc="pred.png")
    # GET request, or POST without a usable file: plain form.
    return render_template("index.html", prediction=0, image_loc=None)


if __name__ == "__main__":
    # Load the trained UNet weights once at startup, then serve the app.
    MODEL = UNet()
    MODEL.load_state_dict(torch.load(config.MODEL_LOAD_PATH))
    MODEL.to(DEVICE)
    MODEL.eval()
    app.run(port=12000, debug=True)
예제 #16
0
def main():
    """Evaluate trained UNet denoising models on the burst test set.

    For each experiment directory in ``--exp_path``, loads the newest (or
    the ``--model_iter``-selected) checkpoint, runs it over test bursts of
    random length, saves a few example plots, and writes per-image
    SSIM/PSNR results as CSV files under ``--out_path``.
    """

    parser = argparse.ArgumentParser()
    parser.add_argument('--exp_path',
                        type=list,
                        default=['../../results/model'])
    parser.add_argument('--out_path', type=str, default='.')
    parser.add_argument('--NUMBER_OF_IMAGES', type=int, default=5000)
    parser.add_argument('--NUMBER_OF_PLOTS', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--KERNEL_LVL', type=float, default=3)
    parser.add_argument('--NOISE_LVL', type=float, default=1)
    # NOTE(review): type=bool on argparse treats any non-empty string as
    # True ("--MOTION_BLUR False" is still True); kept to preserve the CLI.
    parser.add_argument('--MOTION_BLUR', type=bool, default=True)
    parser.add_argument('--HOMO_ALIGN', type=bool, default=True)
    parser.add_argument('--model_iter', type=int, default=None)
    args = parser.parse_args()

    print()
    print(args)
    print()

    # Evaluation metric parameters
    SSIM_window_size = 3

    dict_ = {}
    # Bug fix: argparse stores --exp_path as ``args.exp_path``; the old
    # code read the non-existent ``args.exp_paths`` (AttributeError).
    for e, exp_path in enumerate(args.exp_path):

        if args.model_iter is None:
            model_path = get_newest_model(exp_path)
        else:
            # Bug fix: model_iter is an int; os.path.join needs a string.
            model_path = os.path.join(exp_path, str(args.model_iter))

        model_name = os.path.split(model_path)[1]
        name = str(e) + '_' + model_name.replace('.pt', '')

        dict_[name] = {}
        # Bug fix: the option is ``args.out_path`` (from --out_path), not
        # ``args.output_path`` (AttributeError).  Same fix applied below.
        if not os.path.isdir(os.path.join(args.out_path, name)):
            os.mkdir(os.path.join(args.out_path, name))

        model = UNet(in_channel=3, out_channel=3)

        model.load_state_dict(torch.load(model_path))
        model.eval()

        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda:0" if use_cuda else "cpu")

        model = model.to(device)

        # Parameters
        params = {'batch_size': 1, 'shuffle': True, 'num_workers': 0}

        # Fixed seeds so every model is evaluated on the same bursts.
        random.seed(42)
        np.random.seed(42)
        torch.manual_seed(42)

        # Generators
        data_set = Dataset('../../data/test/',
                           max_images=args.NUMBER_OF_IMAGES,
                           kernel_lvl=args.KERNEL_LVL,
                           noise_lvl=args.NOISE_LVL,
                           motion_blur_boolean=args.MOTION_BLUR,
                           homo_align=args.HOMO_ALIGN)
        data_gen = data.DataLoader(data_set, **params)

        # evaluation
        evaluationData = {}

        for i, (X_batch, y_labels) in enumerate(data_gen):
            # Alter the burst length for each mini batch

            burst_length = np.random.randint(
                2,
                9,
            )
            X_batch = X_batch[:, :burst_length, :, :, :]

            # Transfer to GPU
            X_batch, y_labels = X_batch.to(device).type(
                torch.float), y_labels.to(device).type(torch.float)

            with torch.set_grad_enabled(False):
                model.eval()
                pred_batch = model(X_batch)

            evaluationData[str(i)] = {}
            for j in range(params['batch_size']):
                evaluationData[str(i)][str(j)] = {}

                y_label = y_labels[j, :, :, :].detach().cpu().numpy().astype(
                    int)
                pred = pred_batch[j, :, :, :].detach().cpu().numpy().astype(
                    int)

                # CHW -> HWC for plotting/metrics; clamp to valid pixels.
                y_label = np.transpose(y_label, (1, 2, 0))
                pred = np.transpose(pred, (1, 2, 0))
                pred = np.clip(pred, 0, 255)

                if i < args.NUMBER_OF_PLOTS and j == 0:
                    plt.figure(figsize=(20, 5))
                    plt.subplot(1, 2 + len(X_batch[j, :, :, :, :]), 1)
                    plt.imshow(y_label)
                    plt.axis('off')
                    plt.title('GT')

                    plt.subplot(1, 2 + len(X_batch[j, :, :, :, :]), 2)
                    plt.imshow(pred)
                    plt.axis('off')
                    plt.title('Pred')

                # Per-frame metrics of the raw burst, for comparison with
                # the model output.
                burst_ssim = []
                burst_psnr = []
                for k in range(len(X_batch[j, :, :, :, :])):
                    x = X_batch[j,
                                k, :, :, :].detach().cpu().numpy().astype(int)
                    burst = np.transpose(x, (1, 2, 0))

                    if i < args.NUMBER_OF_PLOTS and j == 0:
                        plt.subplot(1, 2 + len(X_batch[j, :, :, :, :]), 3 + k)
                        plt.imshow(burst)
                        plt.axis('off')
                        plt.title('Burst ' + str(k))

                    burst_ssim.append(
                        ssim(y_label.astype(float),
                             burst.astype(float),
                             multichannel=True,
                             win_size=SSIM_window_size))
                    burst_psnr.append(psnr(y_label, burst))

                SSIM = ssim(pred.astype(float),
                            y_label.astype(float),
                            multichannel=True,
                            win_size=SSIM_window_size)
                PSNR = psnr(pred, y_label)
                if i < args.NUMBER_OF_PLOTS and j == 0:
                    plt.savefig(os.path.join(args.out_path, name,
                                             str(i) + '.png'),
                                bbox_inches='tight',
                                pad_inches=0)
                    plt.cla()
                    plt.clf()
                    plt.close()

                evaluationData[str(i)][str(j)]['SSIM'] = SSIM
                evaluationData[str(i)][str(j)]['PSNR'] = PSNR
                evaluationData[str(i)][str(j)]['length'] = burst_length
                evaluationData[str(i)][str(j)]['SSIM_burst'] = burst_ssim
                evaluationData[str(i)][str(j)]['PSNR_burst'] = burst_psnr

            if i % 500 == 0 and i > 0:
                print(i)

        #######
        # Save Results
        #######

        x_ssim, y_ssim, y_max_ssim = [], [], []
        x_psnr, y_psnr, y_max_psnr = [], [], []

        for i in evaluationData:
            for j in evaluationData[i]:
                x_ssim.append(evaluationData[i][j]['length'])
                y_ssim.append(evaluationData[i][j]['SSIM'])
                # Gain of the prediction over the best single burst frame.
                y_max_ssim.append(evaluationData[i][j]['SSIM'] -
                                  max(evaluationData[i][j]['SSIM_burst']))

                x_psnr.append(evaluationData[i][j]['length'])
                y_psnr.append(evaluationData[i][j]['PSNR'])
                y_max_psnr.append(evaluationData[i][j]['PSNR'] -
                                  max(evaluationData[i][j]['PSNR_burst']))

        method = [name] * len(x_ssim)
        dict_[name]['ssim'] = pd.DataFrame(
            np.transpose([x_ssim, y_ssim, y_max_ssim, method]),
            columns=['burst_length', 'ssim', 'max_pred_ssim', 'method'])
        dict_[name]['psnr'] = pd.DataFrame(
            np.transpose([x_psnr, y_psnr, y_max_psnr, method]),
            columns=['burst_length', 'psnr', 'max_pred_psnr', 'method'])

        dict_[name]['ssim'].to_csv(
            os.path.join(args.out_path, 'ssim_' + name + '.csv'))
        dict_[name]['psnr'].to_csv(
            os.path.join(args.out_path, 'psnr_' + name + '.csv'))
예제 #17
0
def train():
    """Train the UNet QA model, checkpointing the best dev-F1 weights.

    Hyper-parameters come from the module-level ``args``.  Supports pure
    evaluation (``--eval``) and resuming from a saved model
    (``--load_model``); F1/EM histories are pickled next to the model.
    """
    if not os.path.exists('train_model/'):
        os.makedirs('train_model/')
    if not os.path.exists('result/'):
        os.makedirs('result/')

    train_data, dev_data, word2id, id2word, char2id, opts = load_data(
        vars(args))
    model = UNet(opts)

    if args.use_cuda:
        model = model.cuda()

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    # Evaluation-only mode: load weights, score the dev set, and exit.
    if args.eval:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        model.Evaluate(dev_batches,
                       args.data_path + 'dev_eval.json',
                       answer_file='result/' + args.model_dir.split('/')[-1] +
                       '.answers',
                       drop_file=args.data_path + 'drop.json',
                       dev=args.data_path + 'dev-v2.0.json')
        exit()

    # Resume: restore weights, re-establish the best score and histories.
    if args.load_model:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        _, F1 = model.Evaluate(dev_batches,
                               args.data_path + 'dev_eval.json',
                               answer_file='result/' +
                               args.model_dir.split('/')[-1] + '.answers',
                               drop_file=args.data_path + 'drop.json',
                               dev=args.data_path + 'dev-v2.0.json')
        best_score = F1
        with open(args.model_dir + '_f1_scores.pkl', 'rb') as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + '_em_scores.pkl', 'rb') as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    # Bug fix: the old ``parameters = filter(...)`` was a one-shot iterator
    # that the Adamax constructor exhausted, so the later
    # clip_grad_norm_(parameters, ...) call received an empty iterable and
    # silently clipped nothing.  Materialize the trainable parameters once.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)

    lrate = args.lrate

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)
        total_size = len(train_data) // args.batch_size

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts['grad_clipping'])
            optimizer.step()
            model.reset_parameters()

            if i % 100 == 0:
                print(
                    'Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f'
                    % (epoch, i, total_size, model.train_loss.value, lrate,
                       best_score))
                sys.stdout.flush()

        # End-of-epoch evaluation; persist the score histories every epoch.
        model.eval()
        exact_match_score, F1 = model.Evaluate(
            dev_batches,
            args.data_path + 'dev_eval.json',
            answer_file='result/' + args.model_dir.split('/')[-1] + '.answers',
            drop_file=args.data_path + 'drop.json',
            dev=args.data_path + 'dev-v2.0.json')
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + '_f1_scores.pkl', 'wb') as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + '_em_scores.pkl', 'wb') as f:
            pkl.dump(em_scores, f)

        # Keep only the best-F1 checkpoint; decay the LR on schedule.
        if best_score < F1:
            best_score = F1
            print('saving %s ...' % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)
        if epoch > 0 and epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lrate
예제 #18
0
def main():
    """Test a nuclei segmentation/classification model (optionally with TTA)
    and report per-class detection and segmentation metrics."""
    params = Params()
    img_dir = params.test['img_dir']
    label_dir = params.test['label_dir']
    save_dir = params.test['save_dir']
    if not os.path.exists(save_dir):
        os.mkdir(save_dir)
    model_path = params.test['model_path']
    save_flag = params.test['save_flag']
    tta = params.test['tta']

    params.save_params('{:s}/test_params.txt'.format(params.test['save_dir']),
                       test=True)

    # check if it is needed to compute accuracies
    eval_flag = True if label_dir else False
    if eval_flag:
        test_results = dict()
        # recall, precision, F1, dice, iou, haus
        tumor_result = utils.AverageMeter(7)
        lym_result = utils.AverageMeter(7)
        stroma_result = utils.AverageMeter(7)
        all_result = utils.AverageMeter(7)
        conf_matrix = np.zeros((3, 3))

    # data transforms
    test_transform = get_transforms(params.transform['test'])

    model_name = params.model['name']
    if model_name == 'ResUNet34':
        model = ResUNet34(params.model['out_c'],
                          fixed_feature=params.model['fix_params'])
    elif params.model['name'] == 'UNet':
        model = UNet(3, params.model['out_c'])
    else:
        raise NotImplementedError()
    # Checkpoints were saved from a DataParallel wrapper, so wrap before
    # loading the state dict, then unwrap (model.module) afterwards.
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        if not os.path.exists(prob_maps_folder):
            os.mkdir(prob_maps_folder)
        if not os.path.exists(seg_folder):
            os.mkdir(seg_folder)

    # img_names = ['193-adca-5']
    # total_time = 0.0
    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        if eval_flag:
            # NOTE(review): scipy.misc.imread/imsave were removed in modern
            # SciPy; this file presumably pins an old SciPy version.
            label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
            gt = misc.imread(label_path)

        # NOTE(review): `input` shadows the Python builtin of the same name.
        input = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input, model, params)
        if tta:
            # Test-time augmentation: average probability maps over the 8
            # symmetries (identity, H/V flips, and the same after a 90°
            # rotation), undoing each transform on the output maps.
            img_hf = img.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_vf = img.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_hvf = img_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_hf = test_transform(
                (img_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_vf = test_transform(
                (img_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_hvf = test_transform((img_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_hf = get_probmaps(input_hf, model, params)
            prob_maps_vf = get_probmaps(input_vf, model, params)
            prob_maps_hvf = get_probmaps(input_hvf, model, params)

            # re flip
            prob_maps_hf = np.flip(prob_maps_hf, 2)
            prob_maps_vf = np.flip(prob_maps_vf, 1)
            prob_maps_hvf = np.flip(np.flip(prob_maps_hvf, 1), 2)

            # rotation 90 and flips
            img_r90 = img.rotate(90, expand=True)
            img_r90_hf = img_r90.transpose(
                Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_r90_vf = img_r90.transpose(
                Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_r90_hvf = img_r90_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_r90 = test_transform((img_r90, ))[0].unsqueeze(0)
            input_r90_hf = test_transform(
                (img_r90_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_r90_vf = test_transform(
                (img_r90_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_r90_hvf = test_transform((img_r90_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_r90 = get_probmaps(input_r90, model, params)
            prob_maps_r90_hf = get_probmaps(input_r90_hf, model, params)
            prob_maps_r90_vf = get_probmaps(input_r90_vf, model, params)
            prob_maps_r90_hvf = get_probmaps(input_r90_hvf, model, params)

            # re flip
            prob_maps_r90 = np.rot90(prob_maps_r90, k=3, axes=(1, 2))
            prob_maps_r90_hf = np.rot90(np.flip(prob_maps_r90_hf, 2),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_vf = np.rot90(np.flip(prob_maps_r90_vf, 1),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_hvf = np.rot90(np.flip(np.flip(prob_maps_r90_hvf, 1),
                                                 2),
                                         k=3,
                                         axes=(1, 2))

            # utils.show_figures((np.array(img), np.array(img_r90_hvf),
            #                     np.swapaxes(np.swapaxes(prob_maps_r90_hvf, 0, 1), 1, 2)))

            prob_maps = (prob_maps + prob_maps_hf + prob_maps_vf +
                         prob_maps_hvf + prob_maps_r90 + prob_maps_r90_hf +
                         prob_maps_r90_vf + prob_maps_r90_hvf) / 8

        # Class map: 0 background, 1 tumor, 2 lym, 3 stroma, 4 contour
        # (class 4 inferred from the "set contours to background" step —
        # confirm against the training label encoding).
        pred = np.argmax(prob_maps, axis=0)  # prediction
        pred_inside = pred.copy()
        pred_inside[pred == 4] = 0  # set contours to background
        pred_nuclei_inside_labeled = measure.label(pred_inside > 0)

        pred_tumor_inside = pred_inside == 1
        pred_lym_inside = pred_inside == 2
        pred_stroma_inside = pred_inside == 3
        pred_3types_inside = pred_tumor_inside + pred_lym_inside * 2 + pred_stroma_inside * 3

        # find the correct class for each segmented nucleus
        # Majority vote: histogram nucleus-instance labels against the
        # 3-class map and take the argmax class per instance.
        N_nuclei = len(np.unique(pred_nuclei_inside_labeled))
        N_class = len(np.unique(pred_3types_inside))
        intersection = np.histogram2d(pred_nuclei_inside_labeled.flatten(),
                                      pred_3types_inside.flatten(),
                                      bins=(N_nuclei, N_class))[0]
        classes = np.argmax(intersection, axis=1)
        tumor_nuclei_indices = np.nonzero(classes == 1)
        lym_nuclei_indices = np.nonzero(classes == 2)
        stroma_nuclei_indices = np.nonzero(classes == 3)

        # solve the problem of one nucleus assigned with different labels
        pred_tumor_inside = np.isin(pred_nuclei_inside_labeled,
                                    tumor_nuclei_indices)
        pred_lym_inside = np.isin(pred_nuclei_inside_labeled,
                                  lym_nuclei_indices)
        pred_stroma_inside = np.isin(pred_nuclei_inside_labeled,
                                     stroma_nuclei_indices)

        # remove small objects
        pred_tumor_inside = morph.remove_small_objects(pred_tumor_inside,
                                                       params.post['min_area'])
        pred_lym_inside = morph.remove_small_objects(pred_lym_inside,
                                                     params.post['min_area'])
        pred_stroma_inside = morph.remove_small_objects(
            pred_stroma_inside, params.post['min_area'])

        # connected component labeling
        # Combined labeling packs the three classes into disjoint label
        # ranges: tumor -> 3k, lym -> 3k-2, stroma -> 3k-1.
        pred_tumor_inside_labeled = measure.label(pred_tumor_inside)
        pred_lym_inside_labeled = measure.label(pred_lym_inside)
        pred_stroma_inside_labeled = measure.label(pred_stroma_inside)
        pred_all_inside_labeled = pred_tumor_inside_labeled * 3 \
                                  + (pred_lym_inside_labeled * 3 - 2) * (pred_lym_inside_labeled>0) \
                                  + (pred_stroma_inside_labeled * 3 - 1) * (pred_stroma_inside_labeled>0)

        # dilation
        # Grow each instance back by the contour radius removed earlier.
        pred_tumor_labeled = morph.dilation(pred_tumor_inside_labeled,
                                            selem=morph.selem.disk(
                                                params.post['radius']))
        pred_lym_labeled = morph.dilation(pred_lym_inside_labeled,
                                          selem=morph.selem.disk(
                                              params.post['radius']))
        pred_stroma_labeled = morph.dilation(pred_stroma_inside_labeled,
                                             selem=morph.selem.disk(
                                                 params.post['radius']))
        pred_all_labeled = morph.dilation(pred_all_inside_labeled,
                                          selem=morph.selem.disk(
                                              params.post['radius']))

        # utils.show_figures([pred, pred2, pred_labeled])

        if eval_flag:
            print('\tComputing metrics...')
            # GT encoding mirrors the prediction packing: label % 3 selects
            # the class (0 tumor, 1 lym, 2 stroma) — TODO confirm.
            gt_tumor = (gt % 3 == 0) * gt
            gt_lym = (gt % 3 == 1) * gt
            gt_stroma = (gt % 3 == 2) * gt

            tumor_detect_metrics = utils.accuracy_detection_clas(
                pred_tumor_labeled, gt_tumor, clas_flag=False)
            lym_detect_metrics = utils.accuracy_detection_clas(
                pred_lym_labeled, gt_lym, clas_flag=False)
            stroma_detect_metrics = utils.accuracy_detection_clas(
                pred_stroma_labeled, gt_stroma, clas_flag=False)
            all_detect_metrics = utils.accuracy_detection_clas(
                pred_all_labeled, gt, clas_flag=True)

            tumor_seg_metrics = utils.accuracy_object_level(
                pred_tumor_labeled, gt_tumor, hausdorff_flag=False)
            lym_seg_metrics = utils.accuracy_object_level(pred_lym_labeled,
                                                          gt_lym,
                                                          hausdorff_flag=False)
            stroma_seg_metrics = utils.accuracy_object_level(
                pred_stroma_labeled, gt_stroma, hausdorff_flag=False)
            all_seg_metrics = utils.accuracy_object_level(pred_all_labeled,
                                                          gt,
                                                          hausdorff_flag=True)

            # Last element of the detection metrics is the confusion matrix
            # contribution; it is stripped from the per-class records.
            tumor_metrics = [*tumor_detect_metrics[:-1], *tumor_seg_metrics]
            lym_metrics = [*lym_detect_metrics[:-1], *lym_seg_metrics]
            stroma_metrics = [*stroma_detect_metrics[:-1], *stroma_seg_metrics]
            all_metrics = [*all_detect_metrics[:-1], *all_seg_metrics]
            conf_matrix += np.array(all_detect_metrics[-1])

            # save result for each image
            test_results[name] = {
                'tumor': tumor_metrics,
                'lym': lym_metrics,
                'stroma': stroma_metrics,
                'all': all_metrics
            }

            # update the average result
            tumor_result.update(tumor_metrics)
            lym_result.update(lym_metrics)
            stroma_result.update(stroma_metrics)
            all_result.update(all_metrics)

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                        pred.astype(np.uint8) * 50)
            misc.imsave(
                '{:s}/{:s}_prob_tumor.png'.format(prob_maps_folder, name),
                prob_maps[1, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_lym.png'.format(prob_maps_folder, name),
                prob_maps[2, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_stroma.png'.format(prob_maps_folder, name),
                prob_maps[3, :, :])
            # np.save('{:s}/{:s}_prob.npy'.format(prob_maps_folder, name), prob_maps)
            # np.save('{:s}/{:s}_seg.npy'.format(seg_folder, name), pred_all_labeled)
            final_pred = Image.fromarray(pred_all_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored = np.zeros((ori_h, ori_w, 3))
            pred_colored_instance = np.zeros((ori_h, ori_w, 3))
            pred_colored[pred_tumor_labeled > 0] = np.array([255, 0, 0])
            pred_colored[pred_lym_labeled > 0] = np.array([0, 255, 0])
            pred_colored[pred_stroma_labeled > 0] = np.array([0, 0, 255])
            filename = '{:s}/{:s}_seg_colored_3types.png'.format(
                seg_folder, name)
            misc.imsave(filename, pred_colored)
            for k in range(1, pred_all_labeled.max() + 1):
                pred_colored_instance[pred_all_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored_instance)

            # img_overlaid = utils.overlay_edges(label_img, pred_labeled2, img)
            # filename = '{:s}/{:s}_comparison.png'.format(seg_folder, name)
            # misc.imsave(filename, img_overlaid)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    # print('Time: {:4f}'.format(total_time/counter))

    print('=> Processed all {:d} images'.format(counter))
    if eval_flag:
        print(
            'Average: clas_acc\trecall\tprecision\tF1\tdice\tiou\thausdorff\n'
            'tumor: {t[0]:.4f}, {t[1]:.4f}, {t[2]:.4f}, {t[3]:.4f}, {t[4]:.4f}, {t[5]:.4f}, {t[6]:.4f}\n'
            'lym: {l[0]:.4f}, {l[1]:.4f}, {l[2]:.4f}, {l[3]:.4f}, {l[4]:.4f}, {l[5]:.4f}, {l[6]:.4f}\n'
            'stroma: {s[0]:.4f}, {s[1]:.4f}, {s[2]:.4f}, {s[3]:.4f}, {s[4]:.4f}, {s[5]:.4f}, {s[6]:.4f}\n'
            'all: {a[0]:.4f}, {a[1]:.4f}, {a[2]:.4f}, {a[3]:.4f}, {a[4]:.4f}, {a[5]:.4f}, {a[6]:.4f}'
            .format(t=tumor_result.avg,
                    l=lym_result.avg,
                    s=stroma_result.avg,
                    a=all_result.avg))

        header = [
            'clas_acc', 'recall', 'precision', 'F1', 'Dice', 'IoU', 'Hausdorff'
        ]
        save_results(header, tumor_result.avg, lym_result.avg,
                     stroma_result.avg, all_result.avg, test_results,
                     conf_matrix, '{:s}/test_result.txt'.format(save_dir))
예제 #19
0
def main():
    """Train a burst-denoising UNet from the command line.

    Parses CLI options, builds train/val dataloaders, optionally resumes
    from the newest checkpoint in the results directory, then runs the
    training loop with periodic validation, checkpointing and TensorBoard
    logging. No return value; all outputs go to ``--results``.
    """

    def _str2bool(value):
        # Plain `type=bool` is broken in argparse: ANY non-empty string,
        # including "False", converts to True. This keeps the old
        # `--flag True/False` CLI working while fixing the semantics.
        return str(value).lower() in ('true', '1', 'yes', 'on')

    parser = argparse.ArgumentParser()
    parser.add_argument('--bs', metavar='bs', type=int, default=2)
    parser.add_argument('--path', type=str, default='../../data')
    parser.add_argument('--results', type=str, default='../../results/model')
    parser.add_argument('--nw', type=int, default=0)
    parser.add_argument('--max_images', type=int, default=None)
    parser.add_argument('--val_size', type=int, default=None)
    parser.add_argument('--epochs', type=int, default=10)
    parser.add_argument('--lr', type=float, default=0.003)
    parser.add_argument('--lr_decay', type=float, default=0.99997)
    parser.add_argument('--kernel_lvl', type=float, default=1)
    parser.add_argument('--noise_lvl', type=float, default=1)
    parser.add_argument('--motion_blur', type=_str2bool, default=False)
    parser.add_argument('--homo_align', type=_str2bool, default=False)
    parser.add_argument('--resume', type=_str2bool, default=False)

    args = parser.parse_args()

    print()
    print(args)
    print()

    if not os.path.isdir(args.results): os.makedirs(args.results)

    PATH = args.results
    if not args.resume:
        # Record the hyper-parameters once per fresh run.
        with open(PATH + "/param.txt", "a+") as f:
            f.write(str(args))

    writer = SummaryWriter(PATH + '/runs')

    # CUDA for PyTorch
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0' if use_cuda else "cpu")

    # Parameters
    params = {'batch_size': args.bs, 'shuffle': True, 'num_workers': args.nw}

    # Generators
    print('Initializing training set')
    training_set = Dataset(args.path + '/train/', args.max_images,
                           args.kernel_lvl, args.noise_lvl, args.motion_blur,
                           args.homo_align)
    training_generator = data.DataLoader(training_set, **params)

    print('Initializing validation set')
    validation_set = Dataset(args.path + '/test/', args.val_size,
                             args.kernel_lvl, args.noise_lvl, args.motion_blur,
                             args.homo_align)

    validation_generator = data.DataLoader(validation_set, **params)

    # Model
    model = UNet(in_channel=3, out_channel=3)
    if args.resume:
        models_path = get_newest_model(PATH)
        print('loading model from ', models_path)
        model.load_state_dict(torch.load(models_path))

    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        model = torch.nn.DataParallel(model)

    model.to(device)

    # Loss + optimizer
    criterion = BurstLoss()
    optimizer = RAdam(model.parameters(), lr=args.lr)

    scheduler = StepLR(optimizer, step_size=8 // args.bs, gamma=args.lr_decay)
    if args.resume:
        # Continue the global-step counter from the last logged iteration.
        n_iter = np.loadtxt(PATH + '/train.txt', delimiter=',')[:, 0][-1]
    else:
        n_iter = 0

    # Loop over epochs
    for epoch in range(args.epochs):
        train_loss = 0.0

        # Training
        model.train()
        for i, (X_batch, y_labels) in enumerate(training_generator):
            # Alter the burst length for each mini batch
            burst_length = np.random.randint(2, 9)
            X_batch = X_batch[:, :burst_length, :, :, :]

            # Transfer to GPU
            X_batch, y_labels = X_batch.to(device).type(
                torch.float), y_labels.to(device).type(torch.float)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            pred = model(X_batch)
            loss = criterion(pred, y_labels)
            loss.backward()
            optimizer.step()
            scheduler.step()

            train_loss += loss.detach().cpu().numpy()
            writer.add_scalar('training_loss', loss.item(), n_iter)

            if i % 100 == 0 and i > 0:
                loss_printable = str(np.round(train_loss, 2))

                with open(PATH + "/train.txt", "a+") as f:
                    f.write(str(n_iter) + "," + loss_printable + "\n")

                print("training loss ", loss_printable)

                train_loss = 0.0

            if i % 1000 == 0:
                # Unwrap DataParallel before saving so the checkpoint can
                # be loaded on a single-GPU/CPU machine as well.
                if torch.cuda.device_count() > 1:
                    torch.save(
                        model.module.state_dict(),
                        os.path.join(PATH,
                                     'model_' + str(int(n_iter)) + '.pt'))
                else:
                    torch.save(
                        model.state_dict(),
                        os.path.join(PATH,
                                     'model_' + str(int(n_iter)) + '.pt'))

            if i % 1000 == 0:
                # Validation
                val_loss = 0.0
                with torch.set_grad_enabled(False):
                    model.eval()
                    for v, (X_batch,
                            y_labels) in enumerate(validation_generator):
                        # Alter the burst length for each mini batch
                        burst_length = np.random.randint(2, 9)
                        X_batch = X_batch[:, :burst_length, :, :, :]

                        # Transfer to GPU
                        X_batch, y_labels = X_batch.to(device).type(
                            torch.float), y_labels.to(device).type(torch.float)

                        # forward only (grad disabled above)
                        pred = model(X_batch)
                        loss = criterion(pred, y_labels)

                        val_loss += loss.detach().cpu().numpy()

                        if v < 5:
                            im = make_im(pred, X_batch, y_labels)
                            writer.add_image('image_' + str(v), im, n_iter)

                    writer.add_scalar('validation_loss', val_loss, n_iter)

                    loss_printable = str(np.round(val_loss, 2))
                    print('validation loss ', loss_printable)

                    with open(PATH + "/eval.txt", "a+") as f:
                        f.write(str(n_iter) + "," + loss_printable + "\n")

            n_iter += args.bs
예제 #20
0
        # NOTE(review): the two lines below are the tail of a training loop
        # whose `for` header is missing from this excerpt — they are orphaned
        # at loop indentation and will not run as-is. Restore the enclosing
        # epoch/batch loops before use.
        optimizer.step()
        train_loss[-1] += float(loss)
    # Average the accumulated epoch loss over the number of batches seen.
    train_loss[-1] = train_loss[-1] / (i + 1)
    val_loss.append(evaluate(unet, datapath='.', split='val', device=device))
    # clear_output()
    plot_losses(train_loss, val_loss)
    if epoch == 0:
        print(f"Time per epoch: {time.time() - start:.0f} s")

# Plot some examples: run the trained `unet` over a fresh training loader
# and display input / ground truth / prediction triptychs per image.
dataset = SegmentationDataset(root='../data',
                              year='2009',
                              image_set='train',
                              transform=transform_test)
loader = DataLoader(dataset, batch_size=32, shuffle=True)
unet.eval()
with torch.no_grad():
    for (images, segmentations) in loader:
        predictions = unet(images)
        # decode_seg_maps presumably converts class logits/indices into
        # displayable colour maps — TODO confirm against its definition.
        predictions = decode_seg_maps(predictions)
        segmentations = decode_seg_maps(encode_images(segmentations))
        for i, image in enumerate(images):
            fig = plt.figure(figsize=(12, 4))
            plt.subplot(1, 3, 1)
            # CHW tensor -> HWC numpy for imshow
            plt.imshow(image.permute(1, 2, 0).numpy())
            plt.subplot(1, 3, 2)
            plt.imshow(segmentations[i])
            plt.subplot(1, 3, 3)
            plt.imshow(predictions[i])
            plt.show()
예제 #21
0
def train(args):
    """Train a 5-class semantic-segmentation UNet.

    Loads train/val splits from ``args.dataset_path``, writes the colour
    palette to ``args.mask_json_path``, trains with Adam for
    ``args.epochs`` epochs, checkpoints every ``args.save_epoch`` epochs,
    optionally saves a loss-curve figure, and stores the final
    ``state_dict`` at ``args.model_state_dict``.
    """

    # dataset
    print('Reading dataset from {}...'.format(args.dataset_path))
    train_dataset = SSDataset(dataset_path=args.dataset_path, is_train=True)
    val_dataset = SSDataset(dataset_path=args.dataset_path, is_train=False)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

    # mask: persist the colour palette so predictions can be decoded later
    with open(args.mask_json_path, 'w', encoding='utf-8') as mask:
        colors = SSDataset.all_colors
        mask.write(json.dumps(colors))
        print('Mask colors list has been saved in {}'.format(
            args.mask_json_path))

    # model
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()

    # setting
    lr = args.lr  # 1e-3
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = loss_fn

    # run
    train_losses = []
    val_losses = []
    print('Start training...')
    for epoch_idx in range(args.epochs):
        # train
        net.train()
        train_loss = 0.0
        for batch_idx, batch_data in enumerate(train_dataloader):
            xs, ys = batch_data
            if args.cuda:
                xs = xs.cuda()
                ys = ys.cuda()
            ys_pred = net(xs)
            loss = criterion(ys_pred, ys)
            # bug fix: accumulate a Python float, not the loss tensor —
            # summing tensors kept every batch's autograd graph alive
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # val — no gradients needed during evaluation
        net.eval()
        val_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch_data in enumerate(val_dataloader):
                xs, ys = batch_data
                if args.cuda:
                    xs = xs.cuda()
                    ys = ys.cuda()
                ys_pred = net(xs)
                loss = loss_fn(ys_pred, ys)
                val_loss += loss.item()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print('Epoch: {}, Train total loss: {}, Val total loss: {}'.format(
            epoch_idx + 1, train_loss, val_loss))

        # save
        if (epoch_idx + 1) % args.save_epoch == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_path,
                'checkpoint_{}.pth'.format(epoch_idx + 1))
            torch.save(net.state_dict(), checkpoint_path)
            print('Saved Checkpoint at Epoch {} to {}'.format(
                epoch_idx + 1, checkpoint_path))

    # summary
    if args.do_save_summary:
        epoch_range = list(range(1, args.epochs + 1))
        plt.plot(epoch_range, train_losses, 'r', label='Train loss')
        # bug fix: was plotting `val_loss` (final epoch's scalar) instead
        # of the per-epoch history `val_losses`
        plt.plot(epoch_range, val_losses, 'g', label='Val loss')
        plt.legend()
        # bug fix: plt.imsave expects an image array; savefig is the API
        # for writing the current figure to disk
        plt.savefig(args.summary_image)
        print('Summary images have been saved in {}'.format(
            args.summary_image))

    # save
    net.eval()
    torch.save(net.state_dict(), args.model_state_dict)
    print('Saved state_dict in {}'.format(args.model_state_dict))
class Instructor:
    ''' Model training and evaluation '''

    def __init__(self, opt):
        # Build datasets lazily depending on mode: inference only needs
        # the test set; training needs train + val.
        self.opt = opt
        if opt.inference:
            self.testset = TestImageDataset(fdir=opt.impaths['test'],
                                            imsize=opt.imsize)
        else:
            self.trainset = ImageDataset(fdir=opt.impaths['train'],
                                         bdir=opt.impaths['btrain'],
                                         imsize=opt.imsize,
                                         mode='train',
                                         aug_prob=opt.aug_prob,
                                         prefetch=opt.prefetch)
            self.valset = ImageDataset(fdir=opt.impaths['val'],
                                       bdir=opt.impaths['bval'],
                                       imsize=opt.imsize,
                                       mode='val',
                                       aug_prob=opt.aug_prob,
                                       prefetch=opt.prefetch)
        self.model = UNet(n_channels=3,
                          n_classes=1,
                          bilinear=self.opt.use_bilinear)
        if opt.checkpoint:
            self.model.load_state_dict(
                torch.load('./state_dict/{:s}'.format(opt.checkpoint),
                           map_location=self.opt.device))
            print('checkpoint {:s} has been loaded'.format(opt.checkpoint))
        if opt.multi_gpu == 'on':
            self.model = torch.nn.DataParallel(self.model)
        self.model = self.model.to(opt.device)
        self._print_args()

    def _print_args(self):
        """Print parameter counts and every training argument; cache the
        report in ``self.info`` for the final log file."""
        n_trainable_params, n_nontrainable_params = 0, 0
        for p in self.model.parameters():
            # p.numel() returns a plain int; the old torch.prod(...) gave
            # a 0-dim tensor, which rendered as "tensor(...)" in the report
            n_params = p.numel()
            if p.requires_grad:
                n_trainable_params += n_params
            else:
                n_nontrainable_params += n_params
        self.info = 'n_trainable_params: {0}, n_nontrainable_params: {1}\n'.format(
            n_trainable_params, n_nontrainable_params)
        self.info += 'training arguments:\n' + '\n'.join([
            '>>> {0}: {1}'.format(arg, getattr(self.opt, arg))
            for arg in vars(self.opt)
        ])
        if self.opt.device.type == 'cuda':
            # bug fix: previously read the global `opt` (NameError when
            # constructed without one) instead of self.opt
            print('cuda memory allocated:',
                  torch.cuda.memory_allocated(self.opt.device.index))
        print(self.info)

    def _reset_records(self):
        """Reset per-run bookkeeping (best epoch/dice, loss histories,
        checkpoint paths)."""
        self.records = {
            'best_epoch': 0,
            'best_dice': 0,
            'train_loss': list(),
            'val_loss': list(),
            'val_dice': list(),
            'checkpoints': list()
        }

    def _update_records(self, epoch, train_loss, val_loss, val_dice):
        """Append this epoch's metrics; checkpoint the model whenever the
        validation dice improves."""
        if val_dice > self.records['best_dice']:
            path = './state_dict/{:s}_dice{:.4f}_temp{:s}.pt'.format(
                self.opt.model_name, val_dice,
                str(time.time())[-6:])
            # Unwrap DataParallel so the checkpoint loads on any device.
            if self.opt.multi_gpu == 'on':
                torch.save(self.model.module.state_dict(), path)
            else:
                torch.save(self.model.state_dict(), path)
            self.records['best_epoch'] = epoch
            self.records['best_dice'] = val_dice
            self.records['checkpoints'].append(path)
        self.records['train_loss'].append(train_loss)
        self.records['val_loss'].append(val_loss)
        self.records['val_dice'].append(val_dice)

    def _draw_records(self):
        """Summarize the run: keep only the best checkpoint, save loss and
        dice curves, and write a text report."""
        timestamp = str(int(time.time()))
        print('best epoch: {:d}'.format(self.records['best_epoch']))
        print('best train loss: {:.4f}, best val loss: {:.4f}'.format(
            min(self.records['train_loss']), min(self.records['val_loss'])))
        print('best val dice {:.4f}'.format(self.records['best_dice']))
        # The last checkpoint is the best one; rename it and drop the rest.
        os.rename(
            self.records['checkpoints'][-1],
            './state_dict/{:s}_dice{:.4f}_save{:s}.pt'.format(
                self.opt.model_name, self.records['best_dice'], timestamp))
        for path in self.records['checkpoints'][0:-1]:
            os.remove(path)
        # Draw figures
        plt.figure()
        trainloss, = plt.plot(self.records['train_loss'])
        valloss, = plt.plot(self.records['val_loss'])
        plt.legend([trainloss, valloss], ['train', 'val'], loc='upper right')
        plt.title('{:s} loss curve'.format(timestamp))
        plt.savefig('./figs/{:s}_loss.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        plt.figure()
        valdice, = plt.plot(self.records['val_dice'])
        plt.title('{:s} dice curve'.format(timestamp))
        plt.savefig('./figs/{:s}_dice.png'.format(timestamp),
                    format='png',
                    transparent=True,
                    dpi=300)
        # Save report
        report = '\t'.join(
            ['val_dice', 'train_loss', 'val_loss', 'best_epoch', 'timestamp'])
        report += "\n{:.4f}\t{:.4f}\t{:.4f}\t{:d}\t{:s}\n{:s}".format(
            self.records['best_dice'], min(self.records['train_loss']),
            min(self.records['val_loss']), self.records['best_epoch'],
            timestamp, self.info)
        with open('./logs/{:s}_log.txt'.format(timestamp), 'w') as f:
            f.write(report)
        print('report saved:', './logs/{:s}_log.txt'.format(timestamp))

    def _train(self, train_dataloader, criterion, optimizer):
        """Run one training epoch; returns the mean per-sample loss."""
        self.model.train()
        train_loss, n_total, n_batch = 0, 0, len(train_dataloader)
        for i_batch, sample_batched in enumerate(train_dataloader):
            inputs, target = sample_batched[0].to(
                self.opt.device), sample_batched[1].to(self.opt.device)
            predict = self.model(inputs)

            optimizer.zero_grad()
            loss = criterion(predict, target)
            loss.backward()
            optimizer.step()

            # bug fix: len(sample_batched) is the tuple length (always 2),
            # not the batch size — weight by the actual number of samples
            batch_size = target.size(0)
            train_loss += loss.item() * batch_size
            n_total += batch_size

            ratio = int((i_batch + 1) * 50 / n_batch)
            sys.stdout.write("\r[" + ">" * ratio + " " * (50 - ratio) +
                             "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                                      (i_batch + 1) * 100 /
                                                      n_batch))
            sys.stdout.flush()
        print()
        return train_loss / n_total

    def _evaluation(self, val_dataloader, criterion):
        """Evaluate on the validation set; returns (mean loss, mean dice)."""
        self.model.eval()
        val_loss, val_dice, n_total = 0, 0, 0
        with torch.no_grad():
            for sample_batched in val_dataloader:
                inputs, target = sample_batched[0].to(
                    self.opt.device), sample_batched[1].to(self.opt.device)
                predict = self.model(inputs)
                loss = criterion(predict, target)
                dice = dice_coeff(predict, target)
                # bug fix: weight by the batch size, not the tuple length
                batch_size = target.size(0)
                val_loss += loss.item() * batch_size
                val_dice += dice.item() * batch_size
                n_total += batch_size
        return val_loss / n_total, val_dice / n_total

    def run(self):
        """Full training loop over ``self.opt.num_epoch`` epochs, tracking
        records and rendering the final report/figures."""
        _params = filter(lambda p: p.requires_grad, self.model.parameters())
        optimizer = torch.optim.Adam(_params,
                                     lr=self.opt.lr,
                                     weight_decay=self.opt.l2reg)
        criterion = BCELoss2d()
        train_dataloader = DataLoader(dataset=self.trainset,
                                      batch_size=self.opt.batch_size,
                                      shuffle=True)
        val_dataloader = DataLoader(dataset=self.valset,
                                    batch_size=self.opt.batch_size,
                                    shuffle=False)
        self._reset_records()
        for epoch in range(self.opt.num_epoch):
            train_loss = self._train(train_dataloader, criterion, optimizer)
            val_loss, val_dice = self._evaluation(val_dataloader, criterion)
            self._update_records(epoch, train_loss, val_loss, val_dice)
            print(
                '{:d}/{:d} > train loss: {:.4f}, val loss: {:.4f}, val dice: {:.4f}'
                .format(epoch + 1, self.opt.num_epoch, train_loss, val_loss,
                        val_dice))
        self._draw_records()

    def inference(self):
        """Predict on the test set one image at a time and save results."""
        test_dataloader = DataLoader(dataset=self.testset,
                                     batch_size=1,
                                     shuffle=False)
        n_batch = len(test_dataloader)
        # bug fix: switch to eval mode so dropout/batch-norm behave
        # deterministically during inference
        self.model.eval()
        with torch.no_grad():
            for i_batch, sample_batched in enumerate(test_dataloader):
                index, inputs = sample_batched[0], sample_batched[1].to(
                    self.opt.device)
                predict = self.model(inputs)
                self.testset.save_img(index.item(), predict, self.opt.use_crf)
                ratio = int((i_batch + 1) * 50 / n_batch)
                sys.stdout.write(
                    "\r[" + ">" * ratio + " " * (50 - ratio) +
                    "] {}/{} {:.2f}%".format(i_batch + 1, n_batch,
                                             (i_batch + 1) * 100 / n_batch))
                sys.stdout.flush()
        print()
예제 #23
0
# NOTE(review): this evaluation snippet is truncated — the inner loop ends
# right after moving the generator to the GPU, before any prediction or
# MSE/PSNR accumulation is visible. Restore the missing tail before use.
criterionMSE = nn.MSELoss() #.to(device)


transform = transforms.Compose(transform_list)

# File listing the test image stems, one per line.
img_dir = open('/n/holyscratch01/wadduwage_lab/uom_bme/dataset_static_2020/20200105_synthBeads_1/test.txt','r')

avg_mse = 0
avg_psnr = 0

# Evaluate checkpoint(s); range(55, 56) restricts this to epoch 55 only.
for epochs in range(55,56):
    my_model = 'ckpt/train_deep_tfm_loss_mse/fcn_deep_' + str(epochs) + '.pth'
    netG = UNet(n_classes=args.output_nc)
    netG.load_state_dict(torch.load(my_model))
    netG.eval()
    p = 0
    f_path = '/n/holyscratch01/wadduwage_lab/uom_bme/dataset_static_2020/20200105_synthBeads_1/tr_data_1sls/'    
    for line in img_dir:
        print(line)
        # line[0:-1] strips the trailing newline from the stem.
        GT_ = io.imread(f_path + str(line[0:-1]) + '_gt.png')
        # Stack the 32 acquisition modalities into one (32, 128, 128) array.
        modalities = np.zeros((32,128,128))
        for i in range(0,32):
             modalities[i,:,:] = io.imread(f_path + str(line[0:-1]) +'_'+str(i+1) +'.png')  
        depth = modalities.shape[2]
        predicted_im = np.zeros((128,128,1))
        # Flag degenerate (constant-valued) ground-truth images.
        if np.min(np.array(GT_))==np.max(np.array(GT_)):
             print('Yes')
        # max_gt / max_im are normalization constants defined elsewhere in
        # the original script — presumably global intensity maxima; confirm.
        GT = torch.from_numpy(np.divide(GT_,max_gt))
        img = torch.from_numpy(np.divide(modalities,max_im)[None, :, :]).float()
        netG = netG.cuda()
예제 #24
0
# Hold out everything after `train_num_pts` as the test labels...
test_label = labels[train_num_pts:]

# ...and keep the first `train_num_pts` points for training.
all_data = all_data[:train_num_pts]
labels = labels[:train_num_pts]
numpts = len(all_data)
indices = list(range(len(all_data)))

# Loss variants 3 and 4 use per-sample weights, split the same way.
if LOSS_NUM == 3 or LOSS_NUM == 4:
    test_weights = weights[train_num_pts:]
    weights = weights[:train_num_pts]

# Start training
# NOTE(review): this snippet is truncated mid-loop — despite the banner it
# only shows the start of a per-epoch *evaluation* pass (model.eval()).
print("Training model")
for epoch in range(1, n_epochs+1):

    model.eval()
    test_loss = 0.0
    dice_val = 0.0
    numpts = len(test_data)

    # Evaluate one sample at a time (slices keep the batch dimension).
    for i in range(numpts):
        data = test_data[i:i+1]
        target_pt = test_label[i:i+1]

        if LOSS_NUM == 3 or LOSS_NUM == 4:
            # NOTE(review): this slices test_label, not test_weights —
            # looks like a copy-paste bug; confirm against the loss code.
            test_wt = test_label[i:i+1]
            wt = torch.from_numpy(test_wt).float().cuda()

        data = torch.from_numpy(data).float()
        target = torch.from_numpy(target_pt).float()
예제 #25
0
def run_inference(args):
    """Run a trained 3-class UNet over one PALSAR region tile-by-tile.

    Loads the checkpoint at ``args.model_path``, streams 128x128 patches
    for region ``args.region`` / year ``args.year``, collapses the 4-way
    labels to a binary forest / non-forest problem, accumulates loss,
    accuracy and confusion matrices, stitches predictions into a full-size
    map, and saves the map plus pickled confusion matrices.
    """
    model = UNet(input_channels=3, num_classes=3)
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'),
                          strict=False)
    print('Log: Loaded pretrained {}'.format(args.model_path))
    model.eval()
    # annus/Desktop/palsar/
    test_image_path = '/home/annus/Desktop/palsar/palsar_dataset_full/palsar_dataset/palsar_{}_region_{}.tif'.format(
        args.year, args.region)
    test_label_path = '/home/annus/Desktop/palsar/palsar_dataset_full/palsar_dataset/fnf_{}_region_{}.tif'.format(
        args.year, args.region)
    inference_loader = get_inference_loader(image_path=test_image_path,
                                            label_path=test_label_path,
                                            model_input_size=128,
                                            num_classes=4,
                                            one_hot=True,
                                            batch_size=args.bs,
                                            num_workers=4)
    # we need to fill our new generated test image
    generated_map = np.empty(shape=inference_loader.dataset.get_image_size())
    weights = torch.Tensor([1, 1, 2])
    focal_criterion = FocalLoss2d(weight=weights)
    un_confusion_meter = tnt.meter.ConfusionMeter(2, normalized=False)
    confusion_meter = tnt.meter.ConfusionMeter(2, normalized=True)
    total_correct, total_examples = 0, 0
    net_loss = []
    # bug fix: the loop previously ran with autograd enabled, building and
    # retaining a computation graph for every batch during pure inference
    with torch.no_grad():
        for idx, data in enumerate(inference_loader):
            coordinates, test_x, label = data['coordinates'].tolist(
            ), data['input'], data['label']
            out_x, softmaxed = model.forward(test_x)
            pred = torch.argmax(softmaxed, dim=1)
            not_one_hot_target = torch.argmax(label, dim=1)
            # convert to binary classes
            # 0-> noise, 1-> forest, 2-> non-forest, 3-> water
            pred[pred == 0] = 2
            pred[pred == 3] = 2
            not_one_hot_target[not_one_hot_target == 0] = 2
            not_one_hot_target[not_one_hot_target == 3] = 2
            # now convert 1, 2 to 0, 1
            pred -= 1
            not_one_hot_target -= 1
            pred_numpy = pred.numpy().transpose(1, 2, 0)
            # Paste each patch back at its (x, x_, y, y_) tile coordinates.
            for k in range(test_x.shape[0]):
                x, x_, y, y_ = coordinates[k]
                generated_map[x:x_, y:y_] = pred_numpy[:, :, k]
            loss = focal_criterion(
                softmaxed,
                not_one_hot_target)  # dice_criterion(softmaxed, label) #
            accurate = (pred == not_one_hot_target).sum().item()
            numerator = float(accurate)
            denominator = float(
                pred.view(-1).size(0))  # test_x.size(0) * dimension ** 2)
            total_correct += numerator
            total_examples += denominator
            net_loss.append(loss.item())
            un_confusion_meter.add(predicted=pred.view(-1),
                                   target=not_one_hot_target.view(-1))
            confusion_meter.add(predicted=pred.view(-1),
                                target=not_one_hot_target.view(-1))
            # if idx % 5 == 0:
            accuracy = float(numerator) * 100 / denominator
            print(
                '{}, {} -> ({}/{}) output size = {}, loss = {}, accuracy = {}/{} = {:.2f}%'
                .format(args.year, args.region, idx, len(inference_loader),
                        out_x.size(), loss.item(), numerator, denominator,
                        accuracy))
            #################################
    mean_accuracy = total_correct * 100 / total_examples
    mean_loss = np.asarray(net_loss).mean()
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('log: test:: total loss = {:.5f}, total accuracy = {:.5f}%'.format(
        mean_loss, mean_accuracy))
    print('$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$')
    print('---> Confusion Matrix:')
    print(confusion_meter.value())
    # class_names = ['background/clutter', 'buildings', 'trees', 'cars',
    #                'low_vegetation', 'impervious_surfaces', 'noise']
    with open('normalized.pkl', 'wb') as this:
        pkl.dump(confusion_meter.value(), this, protocol=pkl.HIGHEST_PROTOCOL)
    with open('un_normalized.pkl', 'wb') as this:
        pkl.dump(un_confusion_meter.value(),
                 this,
                 protocol=pkl.HIGHEST_PROTOCOL)

    # save_path = 'generated_maps/generated_{}_{}.npy'.format(args.year, args.region)
    save_path = '/home/annus/Desktop/palsar/generated_maps/using_separate_models/generated_{}_{}.npy'.format(
        args.year, args.region)
    np.save(save_path, generated_map)
    #########################################################################################3
    inference_loader.dataset.clear_mem()
예제 #26
0
def train(args):
    """Train a lung-segmentation UNet on the first 600 dataset entries,
    validating on the remainder.

    Runs ``args.epochs`` epochs of alternating train/val phases with Adam
    and a dice loss, and checkpoints to ``args.weights`` every time the
    mean validation dice score improves.
    """
    dataset = open("dataset.csv", "r").readlines()
    train_set = dataset[:600]
    val_set = dataset[600:]
    # (was `root_dir = root_dir = ...` — redundant double assignment)
    root_dir = "data/Lung_Segmentation/"

    train_data = LungSegmentationDataGen(train_set, root_dir, args)
    val_data = LungSegmentationDataGen(val_set, root_dir, args)

    train_dataloader = DataLoader(train_data,
                                  batch_size=5,
                                  shuffle=True,
                                  num_workers=4)

    val_dataloader = DataLoader(val_data,
                                batch_size=5,
                                shuffle=True,
                                num_workers=4)

    dataloaders = {"train": train_dataloader, "val": val_dataloader}

    dataset_sizes = {"train": len(train_set), "val": len(val_set)}

    print("dataset_sizes: {}".format(dataset_sizes))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    model = UNet(in_channels=1)
    model = model.to(device)

    optimizer = optim.Adam(model.parameters())

    loss_train = []
    loss_valid = []

    current_mean_dsc = 0.0
    best_validation_dsc = 0.0

    epochs = args.epochs
    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)
        dice_score_list = []
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()  # Set model to evaluate mode

            # Iterate over data.
            for i, data in enumerate(dataloaders[phase]):
                inputs, y_true = data
                inputs = inputs.to(device)
                y_true = y_true.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # Track gradients only during the training phase.
                with torch.set_grad_enabled(phase == 'train'):
                    # forward pass with batch input
                    y_pred = model(inputs)

                    loss = dice_loss(y_true, y_pred)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        # print("step: {}, train_loss: {}".format(i, loss))
                        loss_train.append(loss.item())

                        # calculate the gradients based on loss
                        loss.backward()

                        # update the weights
                        optimizer.step()

                    if phase == "val":
                        loss_valid.append(loss.item())
                        dsc = dice_score(y_true, y_pred)
                        print("step: {}, val_loss: {}, val dice_score: {}".
                              format(i, loss, dsc))
                        # bug fix: .numpy() on a CUDA tensor raises — move
                        # to CPU first so this also works on GPU runs
                        dice_score_list.append(dsc.detach().cpu().numpy())

                if phase == "train" and (i + 1) % 10 == 0:
                    print("step:{}, train_loss: {}".format(
                        i + 1, np.mean(loss_train)))
                    loss_train = []
            if phase == "val":
                print("mean val_loss: {}".format(np.mean(loss_valid)))
                loss_valid = []
                current_mean_dsc = np.mean(dice_score_list)
                print("validation set dice_score: {}".format(current_mean_dsc))
                # Checkpoint whenever the validation dice improves.
                if current_mean_dsc > best_validation_dsc:
                    best_validation_dsc = current_mean_dsc
                    print("best dice_score on val set: {}".format(
                        best_validation_dsc))
                    model_name = "unet_{0:.2f}.pt".format(best_validation_dsc)
                    torch.save(model.state_dict(),
                               os.path.join(args.weights, model_name))
    def UNet(self):
        if self.mode == 'train':
            transform = transforms.Compose([Normalization(mean=0.5, std=0.5, mode='train'), ToTensor()])

            dataset_train = Dataset(mode = self.mode, data_dir=self.data_dir, image_type = self.image_type, transform=transform)
            loader_train = DataLoader(dataset_train, batch_size=self.batch_size, shuffle=True, num_workers=8)

            # dataset_val = Dataset(data_dir=os.path.join(data_dir, 'val'), transform=transform)
            # loader_val = DataLoader(dataset_val, batch_size=batch_size, shuffle=False, num_workers=8)

            # 그밖에 부수적인 variables 설정하기
            num_data_train = len(dataset_train)
            # num_data_val = len(dataset_val)

            num_batch_train = np.ceil(num_data_train / self.batch_size)
            # num_batch_val = np.ceil(num_data_val / batch_size)
            
        elif self.mode == 'test':
            transform = transforms.Compose([Normalization(mean=0.5, std=0.5, mode='test'), ToTensor()])

            dataset_test = Dataset(mode = self.mode, data_dir=self.data_dir, image_type = self.image_type, transform=transform)
            loader_test = DataLoader(dataset_test, batch_size=self.batch_size, shuffle=False, num_workers=8)

            # 그밖에 부수적인 variables 설정하기
            num_data_test = len(dataset_test)

            num_batch_test = np.ceil(num_data_test / self.batch_size)
        

        fn_tonumpy = lambda x: x.to('cpu').detach().numpy().transpose(0, 2, 3, 1)
        fn_denorm = lambda x, mean, std: (x * std) + mean
        fn_class = lambda x: 1.0 * (x > 0.5)
        
        
        net = UNet().to(self.device)
        
        criterion = torch.nn.MSELoss().to(self.device)
            
        optimizer = torch.optim.Adam(net.parameters(), lr=self.lr)

        

        writer_train = SummaryWriter(log_dir=os.path.join(self.log_dir, 'train'))

            
            
        if self.mode == 'train':
            if self.train_continue == "on":
                net, optimizer = load_model(ckpt_dir=self.ckpt_dir, net=net, optim=optimizer)

            for epoch in range(1, self.num_epoch + 1):
                net.train()
                loss_arr = []

                for batch, data in enumerate(loader_train, 1):
                    # forward pass
                    label = data['label'].to(self.device)
                    input = data['input'].to(self.device)

                    output = net(input)

                    # backward pass
                    optimizer.zero_grad()

                    loss = criterion(output, label)
                    loss.backward()

                    optimizer.step()

                    # 손실함수 계산
                    loss_arr += [loss.item()]

                    

                    # Tensorboard 저장하기
                    label = fn_tonumpy(label)
                    input = fn_tonumpy(fn_denorm(input, mean=0.5, std=0.5))
                    output = fn_tonumpy(fn_class(output))

                    writer_train.add_image('label', label, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')
                    writer_train.add_image('input', input, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')
                    writer_train.add_image('output', output, num_batch_train * (epoch - 1) + batch, dataformats='NHWC')

                writer_train.add_scalar('loss', np.mean(loss_arr), epoch)
                
                print("TRAIN: EPOCH %04d / %04d |  LOSS %.4f" %(epoch, self.num_epoch, np.mean(loss_arr)))
                
                if epoch % 20 == 0:
                    save_model(ckpt_dir=self.ckpt_dir, net=net, optim=optimizer, epoch=0)
                    
            writer_train.close()


        # TEST MODE
        elif self.mode == 'test':
            net, optimizer = load_model(ckpt_dir=self.ckpt_dir, net=net, optim=optimizer)

            with torch.no_grad():
                net.eval()
                loss_arr = []
                id = 1
                for batch, data in enumerate(loader_test, 1):
                    # forward pass
                    input = data['input'].to(self.device)

                    output = net(input)

                    # 손실함수 계산하기
                    #loss = criterion(output, label)

                    #loss_arr += [loss.item()]

                    #print("TEST: BATCH %04d / %04d | " %
                    #    (batch, num_batch_test))

                    # Tensorboard 저장하기
                    output = fn_tonumpy(fn_class(output))
                    
                    for j in range(input.shape[0]):
                        if id == 800:
                            id = 2350
                        print(id)
                        #plt.imsave(os.path.join(self.result_dir, 'png', 'label_%04d.png' % id), label[j].squeeze(), cmap='gray')
                        #plt.imsave(os.path.join(self.result_dir, 'png', 'input_%04d.png' % id), input[j].squeeze(), cmap='gray')
                        plt.imsave(os.path.join(self.result_dir, 'png', 'gt%06d.png' % id), output[j].squeeze(), cmap='gray')
                        id+=1
                        # np.save(os.path.join(result_dir, 'numpy', 'label_%04d.npy' % id), label[j].squeeze())
                        # np.save(os.path.join(result_dir, 'numpy', 'input_%04d.npy' % id), input[j].squeeze())
                        # np.save(os.path.join(result_dir, 'numpy', 'output_%04d.npy' % id), output[j].squeeze())

            print("AVERAGE TEST: BATCH %04d / %04d | LOSS %.4f" %
                (batch, num_batch_test, np.mean(loss_arr)))
예제 #28
0
                               img_tensor=label,
                               global_step=global_step,
                               dataformats='NHWC')
        train_writer.add_image(tag='output',
                               img_tensor=output,
                               global_step=global_step,
                               dataformats='NHWC')

    train_loss_avg = np.mean(train_loss_arr)
    train_writer.add_scalar(tag='loss',
                            scalar_value=train_loss_avg,
                            global_step=epoch)

    # Validation (No Back Propagation)
    with torch.no_grad():
        net.eval()  # Evaluation Mode
        val_loss_arr = list()

        for batch_idx, data in enumerate(val_loader, 1):
            # Forward Propagation
            img = data['img'].to(device)
            label = data['label'].to(device)

            output = net(img)

            # Calc Loss Function
            loss = loss_fn(output, label)
            val_loss_arr.append(loss.item())

            print_form = '[Validation] | Epoch: {:0>4d} / {:0>4d} | Batch: {:0>4d} / {:0>4d} | Loss: {:.4f}'
            print(
예제 #29
0
파일: train.py 프로젝트: agupt013/ALANET
def main(args):
    """Train/evaluate the ALANET UNet on the Adobe240-fps blur dataset.

    NOTE(review): the actual training loop below is disabled (it is kept as a
    large triple-quoted string), so each epoch effectively runs only the
    validation pass; the "train" statistics written to TensorBoard and into
    checkpoints therefore stay at ~0.  A checkpoint is saved whenever the
    validation PSNR improves, and the previously best checkpoint file is
    deleted.

    Args:
        args: Parsed CLI namespace.  Fields read here: checkpoint_dir,
            dataset_root, valid, sequence_length, num_frame_blur,
            train_batch_size, test_batch_size, decode_mode, checkpoint,
            train_continue, init_learning_rate, milestones, epochs,
            progress_iter.

    Returns:
        None.
    """
    writer = SummaryWriter(os.path.join('./logs'))
    # torch.backends.cudnn.benchmark = True
    if not os.path.isdir(args.checkpoint_dir):
        os.mkdir(args.checkpoint_dir)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    train_tf = None
    test_tf = None

    # Channel wise mean calculated on adobe240-fps training dataset
    # NOTE(review): `transform` is built but never passed to the datasets
    # (they receive `train_tf`, which is None) — confirm whether this is
    # intentional.
    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    test_valid = 'validation' if args.valid else 'test'
    train_data = BlurDataset(os.path.join(args.dataset_root, 'train'),
                             seq_len=args.sequence_length,
                             tau=args.num_frame_blur,
                             delta=5,
                             transform=train_tf)
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length,
                            tau=args.num_frame_blur,
                            delta=5,
                            transform=train_tf)

    train_loader = DataLoader(train_data,
                              batch_size=args.train_batch_size,
                              shuffle=True)
    test_loader = DataLoader(test_data,
                             batch_size=args.test_batch_size,
                             shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    # Load weights if a checkpoint was supplied; the checkpoint may either
    # wrap the weights under 'state_dict' or be a raw state dict.
    if args.checkpoint:
        store_dict = torch.load(args.checkpoint)
        try:
            model.load_state_dict(store_dict['state_dict'])
        except KeyError:
            model.load_state_dict(store_dict)

    # NOTE(review): when train_continue is falsy, the history dict is reset
    # even if a checkpoint was loaded above; when truthy, args.checkpoint is
    # re-loaded and must contain 'state_dict' plus the history keys
    # ('loss', 'valLoss', 'valPSNR'), otherwise this raises — confirm intent.
    if args.train_continue:
        store_dict = torch.load(args.checkpoint)
        model.load_state_dict(store_dict['state_dict'])

    else:
        store_dict = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

    model.to(device)
    model.train(True)

    # model = nn.DataParallel(model)

    # TODO DEFINE MORE CRITERIA
    # input(True if device == torch.device('cuda:0') else False)
    criterion = {
        'MSE': nn.MSELoss(),
        'L1': nn.L1Loss(),
        # 'Perceptual': PerceptualLoss(model='net-lin', net='vgg', dataparallel=True,
        #                              use_gpu=True if device == torch.device('cuda:0') else False)
    }

    # Per-loss weights used to combine the criteria above.
    criterion_w = {'MSE': 1.0, 'L1': 10.0, 'Perceptual': 10.0}

    # Define optimizers
    # optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9,weight_decay=5e-4)
    optimizer = optim.Adam(model.parameters(), lr=args.init_learning_rate)

    # Define lr scheduler
    scheduler = optim.lr_scheduler.MultiStepLR(optimizer,
                                               milestones=args.milestones,
                                               gamma=0.1)

    # best_acc = 0.0
    # start = time.time()
    cLoss = store_dict['loss']
    valLoss = store_dict['valLoss']
    valPSNR = store_dict['valPSNR']
    checkpoint_counter = 0  # NOTE(review): never used below

    loss_tracker = {}
    loss_tracker_test = {}

    # Best-so-far validation metrics; used for checkpoint keep/delete logic.
    psnr_old = 0.0
    dssim_old = 0.0

    # NOTE(review): iterates 10 * args.epochs - 1 times, not args.epochs —
    # confirm the multiplier is intentional.
    for epoch in range(1, 10 *
                       args.epochs):  # loop over the dataset multiple times

        # Append and reset
        cLoss.append([])
        valLoss.append([])
        valPSNR.append([])
        running_loss = 0

        # Increment scheduler count
        # NOTE(review): called at the start of the epoch with no preceding
        # optimizer.step() (training is disabled below); recent PyTorch warns
        # about this ordering.
        scheduler.step()

        tqdm_loader = tqdm(range(len(train_loader)), ncols=150)

        loss = 0.0
        psnr_ = 0.0
        dssim_ = 0.0

        loss_tracker = {}
        for loss_fn in criterion.keys():
            loss_tracker[loss_fn] = 0.0

        # Train
        model.train(True)
        # Non-zero sentinels to avoid division by zero in the statistics,
        # since the training loop below is disabled and never adds to them.
        total_steps = 0.01
        total_steps_test = 0.01
        # The original training loop is preserved below as a disabled
        # triple-quoted string (dead code kept verbatim).
        '''for train_idx, data in enumerate(train_loader, 1):
            loss = 0.0
            blur_data, sharpe_data = data
            #import pdb; pdb.set_trace()
            # input(sharpe_data.shape)
            #import pdb; pdb.set_trace()
            interp_idx = int(math.ceil((args.num_frame_blur/2) - 0.49))
            #input(interp_idx)
            if args.decode_mode == 'interp':
                sharpe_data = sharpe_data[:, :, 1::2, :, :]
            elif args.decode_mode == 'deblur':
                sharpe_data = sharpe_data[:, :, 0::2, :, :]
            else:
                #print('\nBoth\n')
                sharpe_data = sharpe_data

            #print(sharpe_data.shape)
            #input(blur_data.shape)
            blur_data = blur_data.to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            try:
                sharpe_data = sharpe_data.squeeze().to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            except:
                sharpe_data = sharpe_data.squeeze(3).to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

            # clear gradient
            optimizer.zero_grad()

            # forward pass
            sharpe_out = model(blur_data)
            # import pdb; pdb.set_trace()
            # input(sharpe_out.shape)

            # compute losses
            # import pdb;
            # pdb.set_trace()
            sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
            B, C, S, Fx, Fy = sharpe_out.shape
            for loss_fn in criterion.keys():
                loss_tmp = 0.0

                if loss_fn == 'Perceptual':
                    for bidx in range(B):
                        loss_tmp += criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                      sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                    # loss_tmp /= B
                else:
                    loss_tmp = criterion_w[loss_fn] * \
                               criterion[loss_fn](sharpe_out, sharpe_data)


                # try:
                # import pdb; pdb.set_trace()
                loss += loss_tmp # if
                # except :
                try:
                    loss_tracker[loss_fn] += loss_tmp.item()
                except KeyError:
                    loss_tracker[loss_fn] = loss_tmp.item()

            # Backpropagate
            loss.backward()
            optimizer.step()

            # statistics
            # import pdb; pdb.set_trace()
            sharpe_out = sharpe_out.detach().cpu().numpy()
            sharpe_data = sharpe_data.cpu().numpy()
            for sidx in range(S):
                for bidx in range(B):
                    psnr_ += psnr(sharpe_out[bidx, :, sidx, :, :], sharpe_data[bidx, :, sidx, :, :]) #, peak=1.0)
                    """dssim_ += dssim(np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                                    np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0, 2)
                                    )"""

            """sharpe_out = sharpe_out.reshape(-1,3, sx, sy).detach().cpu().numpy()
            sharpe_data = sharpe_data.reshape(-1, 3, sx, sy).cpu().numpy()
            for idx in range(sharpe_out.shape[0]):
                # import pdb; pdb.set_trace()
                psnr_ += psnr(sharpe_data[idx], sharpe_out[idx])
                dssim_ += dssim(np.swapaxes(sharpe_data[idx], 2, 0), np.swapaxes(sharpe_out[idx], 2, 0))"""

            # psnr_ /= sharpe_out.shape[0]
            # dssim_ /= sharpe_out.shape[0]
            running_loss += loss.item()
            loss_str = ''
            total_steps += B*S
            for key in loss_tracker.keys():
               loss_str += ' {0} : {1:6.4f} '.format(key, 1.0*loss_tracker[key] / total_steps)

            # set display info
            if train_idx % 5 == 0:
                tqdm_loader.set_description(('\r[Training] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '.format
                                    (epoch, running_loss / total_steps,
                                     psnr_ / total_steps,
                                     dssim_ / total_steps) + loss_str
                                    ))

                tqdm_loader.update(5)
        tqdm_loader.close()'''

        # Validation
        running_loss_test = 0.0
        psnr_test = 0.0
        dssim_test = 0.0
        # print('len', len(test_loader))
        tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)
        # import pdb; pdb.set_trace()

        loss_tracker_test = {}
        for loss_fn in criterion.keys():
            loss_tracker_test[loss_fn] = 0.0

        with torch.no_grad():
            model.eval()
            total_steps_test = 0.0

            for test_idx, data in enumerate(test_loader, 1):
                loss = 0.0
                blur_data, sharpe_data = data
                # interp_idx is computed but never used afterwards.
                interp_idx = int(math.ceil((args.num_frame_blur / 2) - 0.49))
                # input(interp_idx)
                # Keep only the sharp frames relevant to the decode mode:
                # odd indices for interpolation, even for deblurring, all
                # frames otherwise.
                if args.decode_mode == 'interp':
                    sharpe_data = sharpe_data[:, :, 1::2, :, :]
                elif args.decode_mode == 'deblur':
                    sharpe_data = sharpe_data[:, :, 0::2, :, :]
                else:
                    # print('\nBoth\n')
                    sharpe_data = sharpe_data

                # print(sharpe_data.shape)
                # input(blur_data.shape)
                # Crop one spatial dim to 352 and swap the last two axes.
                blur_data = blur_data.to(device)[:, :, :, :352, :].permute(
                    0, 1, 2, 4, 3)
                # Bare except below: squeeze() can collapse the batch dim when
                # batch size is 1, in which case only dim 3 is squeezed.
                try:
                    sharpe_data = sharpe_data.squeeze().to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
                except:
                    sharpe_data = sharpe_data.squeeze(3).to(
                        device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

                # clear gradient
                # NOTE(review): a no-op under torch.no_grad(); leftover from
                # the (disabled) training loop.
                optimizer.zero_grad()

                # forward pass
                sharpe_out = model(blur_data)
                # import pdb; pdb.set_trace()
                # input(sharpe_out.shape)

                # compute losses
                sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
                B, C, S, Fx, Fy = sharpe_out.shape
                # Accumulate each weighted criterion into the total loss.
                for loss_fn in criterion.keys():
                    loss_tmp = 0.0
                    if loss_fn == 'Perceptual':
                        for bidx in range(B):
                            loss_tmp += criterion_w[loss_fn] * \
                                        criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                           sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                        # loss_tmp /= B
                    else:
                        loss_tmp = criterion_w[loss_fn] * \
                                   criterion[loss_fn](sharpe_out, sharpe_data)
                    loss += loss_tmp
                    try:
                        loss_tracker_test[loss_fn] += loss_tmp.item()
                    except KeyError:
                        loss_tracker_test[loss_fn] = loss_tmp.item()

                # Periodically log running averages to TensorBoard.
                # NOTE(review): with training disabled, trainLoss/Train PSNR
                # are effectively 0 / total_steps sentinel values.
                if ((test_idx % args.progress_iter) == args.progress_iter - 1):
                    itr = test_idx + epoch * len(test_loader)
                    # itr_train
                    writer.add_scalars(
                        'Loss', {
                            'trainLoss': running_loss / total_steps,
                            'validationLoss':
                            running_loss_test / total_steps_test
                        }, itr)
                    writer.add_scalar('Train PSNR', psnr_ / total_steps, itr)
                    writer.add_scalar('Test PSNR',
                                      psnr_test / total_steps_test, itr)
                    # import pdb; pdb.set_trace()
                    # writer.add_image('Validation', sharpe_out.permute(0, 2, 3, 1), itr)

                # statistics
                # Accumulate per-frame PSNR/DSSIM over every (batch, step).
                sharpe_out = sharpe_out.detach().cpu().numpy()
                sharpe_data = sharpe_data.cpu().numpy()
                for sidx in range(S):
                    for bidx in range(B):
                        psnr_test += psnr(
                            sharpe_out[bidx, :, sidx, :, :],
                            sharpe_data[bidx, :, sidx, :, :])  #, peak=1.0)
                        dssim_test += dssim(
                            np.moveaxis(sharpe_out[bidx, :, sidx, :, :], 0, 2),
                            np.moveaxis(sharpe_data[bidx, :, sidx, :, :], 0,
                                        2))  #,range=1.0  )

                running_loss_test += loss.item()
                total_steps_test += B * S
                loss_str = ''
                for key in loss_tracker.keys():
                    loss_str += ' {0} : {1:6.4f} '.format(
                        key, 1.0 * loss_tracker_test[key] / total_steps_test)

                # set display info

                tqdm_loader_test.set_description((
                    '\r[Test    ] [Ep {0:6d}] loss: {1:6.4f} PSNR: {2:6.4f} SSIM: {3:6.4f} '
                    .format(epoch, running_loss_test / total_steps_test,
                            psnr_test / total_steps_test,
                            dssim_test / total_steps_test) + loss_str))
                tqdm_loader_test.update(1)
            tqdm_loader_test.close()

        # save model
        # Keep only the best-PSNR checkpoint: delete the previous best file
        # (skipped on the first epoch, before epoch_old exists), then record
        # the new best metrics and save.
        if psnr_old < (psnr_test / total_steps_test):
            if epoch != 1:
                os.remove(
                    os.path.join(
                        args.checkpoint_dir,
                        'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                            epoch_old,
                            str(round(psnr_old, 4)).replace('.', 'pt'),
                            str(round(dssim_old, 4)).replace('.', 'pt'))))
            epoch_old = epoch
            psnr_old = psnr_test / total_steps_test
            dssim_old = dssim_test / total_steps_test

            # NOTE(review): the train_* entries divide zero-initialized
            # trackers by the 0.01 sentinel while training is disabled.
            checkpoint_dict = {
                'epoch': epoch_old,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'train_psnr': psnr_ / total_steps,
                'train_dssim': dssim_ / total_steps,
                'train_mse': loss_tracker['MSE'] / total_steps,
                'train_l1': loss_tracker['L1'] / total_steps,
                # 'train_percp': loss_tracker['Perceptual'] / total_steps,
                'test_psnr': psnr_old,
                'test_dssim': dssim_old,
                'test_mse': loss_tracker_test['MSE'] / total_steps_test,
                'test_l1': loss_tracker_test['L1'] / total_steps_test,
                # 'test_percp': loss_tracker_test['Perceptual'] / total_steps_test,
            }

            torch.save(
                checkpoint_dict,
                os.path.join(
                    args.checkpoint_dir,
                    'epoch-{}-test-psnr-{}-ssim-{}.ckpt'.format(
                        epoch_old,
                        str(round(psnr_old, 4)).replace('.', 'pt'),
                        str(round(dssim_old, 4)).replace('.', 'pt'))))

        # if epoch % args.checkpoint_epoch == 0:
        #    torch.save(model.state_dict(),args.checkpoint_dir + str(int(epoch/100))+".ckpt")

    return None
예제 #30
0
def main(args):
    """Evaluate a trained ALANET UNet on the test (or validation) split.

    Loads weights from ``args.checkpoint`` (if given), runs a single pass over
    the dataset under ``torch.no_grad()``, saves blur-input / ground-truth /
    output images plus a per-batch PSNR/DSSIM marker file under
    ``./imgs/<batch_idx>/``, and displays running loss/PSNR/SSIM via tqdm.

    Fixes over the original:
      * removed an unconditional ``import pdb; pdb.set_trace()`` that halted
        every run inside the evaluation loop;
      * defined ``criterion_w`` (previously a NameError; weights copied from
        the companion training script);
      * defined ``im_nm`` (previously a NameError; 1.0 keeps metrics on the
        network's own value range, matching the training-time computation);
      * the directory-existence check tested ``sharp_names[1]`` while
        ``makedirs`` created ``./imgs/<test_idx>``, so reruns raised
        FileExistsError — replaced with ``os.makedirs(..., exist_ok=True)``.

    Args:
        args: Parsed CLI namespace. Fields read: dataset_root, valid,
            sequence_length, num_frame_blur, test_batch_size, decode_mode,
            checkpoint.

    Returns:
        None.
    """
    # Side effect: creates ./logs for TensorBoard (writer itself is unused).
    writer = SummaryWriter(os.path.join('./logs'))

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print('[MODEL] CUDA DEVICE : {}'.format(device))

    # TODO DEFINE TRAIN AND TEST TRANSFORMS
    train_tf = None
    test_tf = None

    # Channel wise mean calculated on adobe240-fps training dataset
    # NOTE(review): `transform` is built but the dataset receives `train_tf`
    # (None) — confirm whether normalization was meant to be applied.
    mean = [0.429, 0.431, 0.397]
    std = [1, 1, 1]
    normalize = transforms.Normalize(mean=mean, std=std)
    transform = transforms.Compose([transforms.ToTensor(), normalize])

    test_valid = 'validation' if args.valid else 'test'
    test_data = BlurDataset(os.path.join(args.dataset_root, test_valid),
                            seq_len=args.sequence_length, tau=args.num_frame_blur, delta=5, transform=train_tf, return_path=True)

    test_loader = DataLoader(test_data, batch_size=args.test_batch_size, shuffle=False)

    # TODO IMPORT YOUR CUSTOM MODEL
    model = UNet(3, 3, device, decode_mode=args.decode_mode)

    # The checkpoint may wrap the weights under 'state_dict' or be a raw
    # state dict; try the wrapped form first.
    if args.checkpoint:
        store_dict = torch.load(args.checkpoint)
        try:
            print('Loading checkpoint...')
            model.load_state_dict(store_dict['state_dict'])
            print('Done.')
        except KeyError:
            print('Loading checkpoint...')
            model.load_state_dict(store_dict)
            print('Done.')

    model.to(device)
    model.train(False)

    criterion = {
        'MSE': nn.MSELoss(),
        'L1': nn.L1Loss(),
    }
    # Per-loss weights; previously undefined here (NameError). Values match
    # the training script's criterion_w.
    criterion_w = {'MSE': 1.0, 'L1': 10.0, 'Perceptual': 10.0}
    # Scale applied to images before PSNR/DSSIM; previously undefined
    # (NameError). 1.0 matches the unscaled metric computation used during
    # training. TODO confirm whether 255.0 (8-bit range) was intended.
    im_nm = 1.0

    # Evaluation accumulators.
    running_loss_test = 0.0
    psnr_test = 0.0
    dssim_test = 0.0

    tqdm_loader_test = tqdm(range(len(test_loader)), ncols=150)

    loss_tracker_test = {loss_fn: 0.0 for loss_fn in criterion}

    with torch.no_grad():
        model.eval()
        total_steps_test = 0.0
        for test_idx, data in enumerate(test_loader, 1):
            loss = 0.0
            blur_data, sharpe_data, sharp_names = data

            # Keep only the sharp frames relevant to the decode mode:
            # odd indices for interpolation, even for deblurring, all frames
            # otherwise.
            if args.decode_mode == 'interp':
                sharpe_data = sharpe_data[:, :, 1::2, :, :]
            elif args.decode_mode == 'deblur':
                sharpe_data = sharpe_data[:, :, 0::2, :, :]

            # Crop one spatial dim to 352 and swap the last two axes.
            blur_data = blur_data.to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            try:
                sharpe_data = sharpe_data.squeeze().to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)
            except Exception:
                # squeeze() collapsed too many dims (e.g. batch size 1);
                # only drop dim 3 instead.
                sharpe_data = sharpe_data.squeeze(3).to(device)[:, :, :, :352, :].permute(0, 1, 2, 4, 3)

            # Forward pass.
            sharpe_out = model(blur_data).float()

            # Combine the weighted criteria into the total loss.
            sharpe_out = sharpe_out.permute(0, 2, 1, 3, 4)
            B, C, S, Fx, Fy = sharpe_out.shape
            for loss_fn in criterion.keys():
                if loss_fn == 'Perceptual':
                    loss_tmp = 0.0
                    for bidx in range(B):
                        loss_tmp += criterion_w[loss_fn] * \
                                    criterion[loss_fn](sharpe_out[bidx].permute(1, 0, 2, 3),
                                                       sharpe_data[bidx].permute(1, 0, 2, 3)).sum()
                else:
                    loss_tmp = criterion_w[loss_fn] * \
                               criterion[loss_fn](sharpe_out, sharpe_data)
                loss += loss_tmp
                # Keys are pre-initialized above, so plain += is safe.
                loss_tracker_test[loss_fn] += loss_tmp.item()

            # Save qualitative results and accumulate per-frame metrics.
            for sidx in range(S):
                for bidx in range(B):
                    # Fix: the original tested os.path.exists on
                    # sharp_names[1] while creating ./imgs/<test_idx>.
                    os.makedirs('./imgs/{}'.format(test_idx), exist_ok=True)

                    blur_path = './imgs/{}/blur_input_{}.jpg'.format(test_idx, sidx)
                    imsave(blur_data, blur_path, bidx, sidx)

                    sharp_path = './imgs/{}/sharpe_gt_{}{}.jpg'.format(test_idx, sidx, sidx)
                    imsave(sharpe_data, sharp_path, bidx, sidx)

                    deblur_path = './imgs/{}/out_{}{}.jpg'.format(test_idx, sidx, sidx)
                    imsave(sharpe_out, deblur_path, bidx, sidx)

                    # Also store the interpolated-frame views between steps.
                    if sidx > 0 and sidx < S:
                        interp_path = './imgs/{}/out_{}{}.jpg'.format(test_idx, sidx-1, sidx)
                        imsave(sharpe_out, interp_path, bidx, sidx)
                        sharp_path = './imgs/{}/sharpe_gt_{}{}.jpg'.format(test_idx, sidx-1, sidx)
                        imsave(sharpe_data, sharp_path, bidx, sidx)

                    psnr_local = psnr(im_nm * sharpe_out[bidx, :, sidx, :, :].detach().cpu().numpy(),
                                      im_nm * sharpe_data[bidx, :, sidx, :, :].cpu().numpy())
                    dssim_local = dssim(np.moveaxis(im_nm * sharpe_out[bidx, :, sidx, :, :].cpu().numpy(), 0, 2),
                                        np.moveaxis(im_nm * sharpe_data[bidx, :, sidx, :, :].cpu().numpy(), 0, 2)
                                        )
                    psnr_test += psnr_local
                    dssim_test += dssim_local

            # Empty marker file encoding the LAST frame's metrics in its name.
            # NOTE(review): only the final (bidx, sidx) sample's values are
            # recorded here — confirm a per-batch average was not intended.
            f = open('./imgs/{0}/psnr-{1:.4f}-dssim-{2:.4f}.txt'.format(test_idx, psnr_local/(B), dssim_local/(B)),'w')
            f.close()

            running_loss_test += loss.item()
            total_steps_test += B*S
            loss_str = ''
            for key in loss_tracker_test.keys():
                loss_str += ' {0} : {1:6.4f} '.format(key, 1.0 * loss_tracker_test[key] / total_steps_test)

            # Update the progress-bar display with running averages.
            tqdm_loader_test.set_description(
                        ('\r[Test    ] loss: {0:6.4f} PSNR: {1:6.4f} SSIM: {2:6.4f} '.format
                         ( running_loss_test / total_steps_test,
                          psnr_test / total_steps_test,
                          dssim_test / total_steps_test
                          ) + loss_str
                         )
                    )
            tqdm_loader_test.update(1)
        tqdm_loader_test.close()
    return None