コード例 #1
0
class Solver(object):
	def __init__(self, config, train_loader, valid_loader, test_loader):

		# Data loader
		self.train_loader = train_loader
		self.valid_loader = valid_loader
		self.test_loader = test_loader

		# Models
		self.unet = None
		self.optimizer = None
		self.img_ch = config.img_ch
		self.output_ch = config.output_ch
		#self.criterion = torch.nn.BCELoss()
		self.criterion = nn.CrossEntropyLoss()
		self.augmentation_prob = config.augmentation_prob

		# Hyper-parameters
		self.lr = config.lr
		self.beta1 = config.beta1
		self.beta2 = config.beta2

		# Training settings
		self.num_epochs = config.num_epochs
		self.num_epochs_decay = config.num_epochs_decay
		self.batch_size = config.batch_size

		# Step size
		self.log_step = config.log_step
		self.val_step = config.val_step

		# Path
		self.model_path = config.model_path
		self.result_path = config.result_path
		self.mode = config.mode

		self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
		self.model_type = config.model_type
		self.t = config.t
		self.build_model()

	def build_model(self):
		"""Build generator and discriminator."""
		if self.model_type =='U_Net':
			self.unet = U_Net(img_ch=self.img_ch,output_ch=self.output_ch)
		elif self.model_type =='R2U_Net':
			self.unet = R2U_Net(img_ch=self.img_ch,output_ch=self.output_ch,t=self.t)
		elif self.model_type =='AttU_Net':
			self.unet = AttU_Net(img_ch=self.img_ch,output_ch=self.output_ch)
		elif self.model_type == 'R2AttU_Net':
			self.unet = R2AttU_Net(img_ch=self.img_ch,output_ch=self.output_ch,t=self.t)
			

		self.optimizer = optim.Adam(list(self.unet.parameters()),
					    self.lr, [self.beta1, self.beta2])
		self.unet.to(self.device)

		# self.print_network(self.unet, self.model_type)

	def print_network(self, model, name):
		"""Print the module, then its label, then its total parameter count."""
		total = sum(param.numel() for param in model.parameters())
		print(model)
		print(name)
		print("The number of parameters: {}".format(total))

	def to_data(self, x):
		"""Return ``x.data``, first copying ``x`` to the CPU when CUDA is available."""
		host = x.cpu() if torch.cuda.is_available() else x
		return host.data

	def update_lr(self, g_lr, d_lr):
		for param_group in self.optimizer.param_groups:
			param_group['lr'] = lr

	def reset_grad(self):
		"""Zero the gradient buffers."""
		self.unet.zero_grad()

	def compute_accuracy(self,SR,GT):
		"""Return the fraction of elements where thresholded ``SR`` equals ``GT``.

		Args:
			SR: tensor of scores; values above 0.5 count as positive.
			GT: tensor with the same number of elements as ``SR``. An element
				matches when it equals the thresholded prediction (1 where
				``SR > 0.5``, 0 elsewhere).

		Returns:
			float: matching elements divided by total elements, or 0.0 when the
			tensors are empty.
		"""
		# reshape (not view) so non-contiguous inputs are accepted as well.
		SR_flat = SR.reshape(-1).detach().cpu()
		GT_flat = GT.reshape(-1).detach().cpu()
		if SR_flat.numel() == 0:
			return 0.0

		predicted = (SR_flat > 0.5).to(GT_flat.dtype)
		# The comparison used to be computed and then dropped; return it.
		return (GT_flat == predicted).float().mean().item()

	def tensor2img(self,x):
		"""Turn a two-channel map into a 0/255 float mask.

		Positions where channel 0 is strictly greater than channel 1 become
		255.0, every other position 0.0; the channel axis (dim 1) is dropped.
		"""
		first = x[:,0,:,:]
		second = x[:,1,:,:]
		mask = first.gt(second).float()
		return mask*255


	def train(self):
		"""Train encoder, generator and discriminator."""

		#====================================== Training ===========================================#
		#===========================================================================================#
		
		unet_path = os.path.join(self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %(self.model_type,self.num_epochs,self.lr,self.num_epochs_decay,self.augmentation_prob))

		# U-Net Train
		if os.path.isfile(unet_path):
			# Load the pretrained Encoder
			self.unet.load_state_dict(torch.load(unet_path))
			print('%s is Successfully Loaded from %s'%(self.model_type,unet_path))
		else:
			# Train for Encoder
			lr = self.lr
			best_unet_score = 0.
			
			for epoch in range(self.num_epochs):

				self.unet.train(True)
				epoch_loss = 0
				
				acc = 0.	# Accuracy
				SE = 0.		# Sensitivity (Recall)
				SP = 0.		# Specificity
				PC = 0. 	# Precision
				F1 = 0.		# F1 Score
				JS = 0.		# Jaccard Similarity
				DC = 0.		# Dice Coefficient
				length = 0

				for i, (images, GT) in enumerate(self.train_loader):
					# GT : Ground Truth

					images = images.to(self.device)
					GT = GT.to(self.device)

					# SR : Segmentation Result
					SR = self.unet(images)
					#SR_probs = F.sigmoid(SR)
					#print(SR_probs.size()); exit(1)
					#SR_flat = SR_probs.view(SR_probs.size(0),-1)

					#GT_flat = GT.view(GT.size(0),-1)
					loss = self.criterion(SR, GT)
					epoch_loss += loss.item()

					# Backprop + optimize
					self.reset_grad()
					loss.backward()
					self.optimizer.step()

					acc += get_accuracy(SR,GT)
					SE += get_sensitivity(SR,GT)
					SP += get_specificity(SR,GT)
					PC += get_precision(SR,GT)
					F1 += get_F1(SR,GT)
					JS += get_JS(SR,GT)
					DC += get_DC(SR,GT)
					length += images.size(0)

				acc = acc/length
				SE = SE/length
				SP = SP/length
				PC = PC/length
				F1 = F1/length
				JS = JS/length
				DC = DC/length

				# Print the log info
				print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
					  epoch+1, self.num_epochs, \
					  epoch_loss,\
					  acc,SE,SP,PC,F1,JS,DC))

			

				# Decay learning rate
				if (epoch+1) > (self.num_epochs - self.num_epochs_decay):
					lr -= (self.lr / float(self.num_epochs_decay))
					for param_group in self.optimizer.param_groups:
						param_group['lr'] = lr
					print ('Decay learning rate to lr: {}.'.format(lr))
				
				
				#===================================== Validation ====================================#
				self.unet.train(False)
				self.unet.eval()

				acc = 0.	# Accuracy
				SE = 0.		# Sensitivity (Recall)
				SP = 0.		# Specificity
				PC = 0. 	# Precision
				F1 = 0.		# F1 Score
				JS = 0.		# Jaccard Similarity
				DC = 0.		# Dice Coefficient
				length=0
				for i, (images, GT) in enumerate(self.valid_loader):

					images = images.to(self.device)
					GT = GT.to(self.device)
					#SR = F.sigmoid(self.unet(images))
					SR = self.unet(images)
					acc += get_accuracy(SR,GT)
					SE += get_sensitivity(SR,GT)
					SP += get_specificity(SR,GT)
					PC += get_precision(SR,GT)
					F1 += get_F1(SR,GT)
					JS += get_JS(SR,GT)
					DC += get_DC(SR,GT)
						
					length += images.size(0)
					
				acc = acc/length
				SE = SE/length
				SP = SP/length
				PC = PC/length
				F1 = F1/length
				JS = JS/length
				DC = DC/length
				unet_score = JS + DC

				print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'%(acc,SE,SP,PC,F1,JS,DC))
				
				'''
				torchvision.utils.save_image(images.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_image.png'%(self.model_type,epoch+1)))
				torchvision.utils.save_image(SR.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_SR.png'%(self.model_type,epoch+1)))
				torchvision.utils.save_image(GT.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_GT.png'%(self.model_type,epoch+1)))
				'''


				# Save Best U-Net model
				if unet_score > best_unet_score:
					best_unet_score = unet_score
					best_epoch = epoch
					best_unet = self.unet.state_dict()
					print('Best %s model score : %.4f'%(self.model_type,best_unet_score))
					torch.save(best_unet,unet_path)
					
			#===================================== Test ====================================#
			del self.unet
			del best_unet
			self.build_model()
			self.unet.load_state_dict(torch.load(unet_path))
			
			self.unet.train(False)
			self.unet.eval()

			acc = 0.	# Accuracy
			SE = 0.		# Sensitivity (Recall)
			SP = 0.		# Specificity
			PC = 0. 	# Precision
			F1 = 0.		# F1 Score
			JS = 0.		# Jaccard Similarity
			DC = 0.		# Dice Coefficient
			length=0
			for i, (images, GT) in enumerate(self.valid_loader):

				images = images.to(self.device)
				GT = GT.to(self.device)
				#SR = F.sigmoid(self.unet(images))
				SR = self.unet(images)
				acc += get_accuracy(SR,GT)
				SE += get_sensitivity(SR,GT)
				SP += get_specificity(SR,GT)
				PC += get_precision(SR,GT)
				F1 += get_F1(SR,GT)
				JS += get_JS(SR,GT)
				DC += get_DC(SR,GT)
						
				length += images.size(0)
					
			acc = acc/length
			SE = SE/length
			SP = SP/length
			PC = PC/length
			F1 = F1/length
			JS = JS/length
			DC = DC/length
			unet_score = JS + DC


			f = open(os.path.join(self.result_path,'result.csv'), 'a', encoding='utf-8', newline='')
			wr = csv.writer(f)
			wr.writerow([self.model_type,acc,SE,SP,PC,F1,JS,DC,self.lr,best_epoch,self.num_epochs,self.num_epochs_decay,self.augmentation_prob])
			f.close()
コード例 #2
0
    #   - For 2 classes, use n_classes=1
    #   - For N > 2 classes, use n_classes=N
    #net = UNet(n_channels=3, n_classes=1, bilinear=True)
    #net = R2U_Net(n_channels=1, n_classes=1, bilinear=True)
    net = AttU_Net()
    '''
    logging.info(f'Network:\n'
                 f'\t{net.n_channels} input channels\n'
                 f'\t{net.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if net.bilinear else "Transposed conv"} upscaling')
    '''
    if args.load:
        net.load_state_dict(torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    net.to(device=device)
    # faster convolutions, but more memory
    # cudnn.benchmark = True

    try:
        train_net(net=net,
                  epochs=args.epochs,
                  batch_size=args.batchsize,
                  lr=args.lr,
                  device=device,
                  img_scale=args.scale,
                  val_percent=args.val / 100)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        try:
コード例 #3
0
ファイル: solver.py プロジェクト: zkqiu/Image_Segmentation
class Solver(object):
    def __init__(self, config, train_loader, valid_loader, test_loader):
        """Keep the loaders, copy the run settings from ``config`` and build the model.

        Args:
            config: object exposing the settings copied below as attributes.
            train_loader: iterable of ``(images, GT)`` training batches.
            valid_loader: iterable of ``(images, GT)`` validation batches.
            test_loader: stored on the instance.
        """
        self.train_loader, self.valid_loader, self.test_loader = (
            train_loader, valid_loader, test_loader)

        # Filled in by build_model() at the end of construction.
        self.unet = None
        self.optimizer = None

        copied_settings = (
            # model
            'img_ch', 'output_ch', 'augmentation_prob',
            # hyper-parameters
            'lr', 'beta1', 'beta2',
            # training schedule
            'num_epochs', 'num_epochs_decay', 'batch_size',
            # step sizes
            'log_step', 'val_step',
            # paths / run mode
            'model_path', 'result_path', 'mode',
            # architecture selection
            'model_type', 't',
        )
        for key in copied_settings:
            setattr(self, key, getattr(config, key))

        self.criterion = dice_loss()

        use_gpu = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_gpu else 'cpu')

        self.build_model()

    def build_model(self):
        """Build generator and discriminator."""
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=2)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        self.unet.to(self.device)

        # self.print_network(self.unet, self.model_type)

    def print_network(self, model, name):
        """Write the module repr, its label and its parameter count to stdout."""
        sizes = [tensor.numel() for tensor in model.parameters()]
        count = sum(sizes)
        print(model)
        print(name)
        print("The number of parameters: {}".format(count))

    def to_data(self, x):
        """Return ``x.data``; when CUDA is available ``x`` is copied to the CPU first."""
        if not torch.cuda.is_available():
            return x.data
        return x.cpu().data

    def update_lr(self, g_lr, d_lr):
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

    def reset_grad(self):
        """Zero the gradient buffers."""
        self.unet.zero_grad()

    def compute_accuracy(self, SR, GT):
        """Return the fraction of elements where thresholded ``SR`` equals ``GT``.

        Args:
            SR: tensor of scores; values above 0.5 count as positive.
            GT: tensor with the same number of elements as ``SR``. An element
                matches when it equals the thresholded prediction (1 where
                ``SR > 0.5``, 0 elsewhere).

        Returns:
            float: matching elements divided by total elements, or 0.0 when
            the tensors are empty.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        SR_flat = SR.reshape(-1).detach().cpu()
        GT_flat = GT.reshape(-1).detach().cpu()
        if SR_flat.numel() == 0:
            return 0.0

        predicted = (SR_flat > 0.5).to(GT_flat.dtype)
        # The comparison used to be computed and then dropped; return it.
        return (GT_flat == predicted).float().mean().item()

    def tensor2img(self, x):
        """Build a 0/255 float mask from a two-channel map.

        The result is 255.0 wherever channel 0 is strictly greater than
        channel 1 and 0.0 elsewhere; the channel axis (dim 1) is dropped.
        """
        channel0 = x[:, 0, :, :]
        channel1 = x[:, 1, :, :]
        return channel0.gt(channel1).float() * 255

    def train(self):
        """Train encoder, generator and discriminator."""

        #====================================== Training ===========================================#
        #===========================================================================================#

        unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %
            (self.model_type, self.num_epochs, self.lr, self.num_epochs_decay,
             self.augmentation_prob))

        # U-Net Train
        if os.path.isfile(unet_path):
            # Load the pretrained Encoder
            self.unet.load_state_dict(torch.load(unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, unet_path))
        else:
            # Train for Encoder
            lr = self.lr
            best_unet_score = 0.

            for epoch in range(self.num_epochs):

                self.unet.train(True)
                epoch_loss = 0

                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                length = 0

                for i, (images, GT) in enumerate(self.train_loader):
                    # GT : Ground Truth
                    if i == 1:
                        gt = GT.numpy()
                        # print (gt.max())
                    images = images.to(self.device)
                    GT = GT.to(self.device).squeeze(1)

                    # SR : Segmentation Result
                    SR = self.unet(images)
                    SR_probs = F.softmax(SR, dim=1)
                    # SR_probs = F.softmax(SR)
                    SR_flat = SR_probs

                    GT_flat = GT

                    #print(SR_flat.requires_grad,GT_flat.requires_grad)
                    loss = self.criterion(SR_flat, GT_flat.long())
                    #print(loss)
                    epoch_loss += loss.item()

                    # Backprop + optimize
                    self.reset_grad()
                    loss.backward()
                    self.optimizer.step()

                    acc += get_accuracy(SR_probs[:, 1:2, :, :], GT)
                    SE += get_sensitivity(SR_probs[:, 1:2, :, :], GT)
                    SP += get_specificity(SR_probs[:, 1:2, :, :], GT)
                    PC += get_precision(SR_probs[:, 1:2, :, :], GT)
                    F1 += get_F1(SR_probs[:, 1:2, :, :], GT)
                    JS += get_JS(SR_probs[:, 1:2, :, :], GT)
                    DC += get_DC(SR_probs[:, 1:2, :, :], GT)
                    length += images.size(0)
                    if i % 50 == 0 or i == len(self.train_loader) - 1:
                        vis_dir = './train_viz/'

                        probs = SR_probs
                        y = GT

                        images = images.permute(0, 2, 3, 1)
                        os.system('rm -rf %s' % (vis_dir))
                        os.system('mkdir %s' % (vis_dir))
                        for j in range(0, images.size()[0]):
                            img = images[
                                j, :, :, :].data.cpu().numpy().squeeze()
                            img -= np.min(img)
                            img /= np.max(img) / 255.
                            img = img[:, :, ::-1]
                            img = img.astype(np.uint8)
                            #img = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_GRAY2BGR)

                            viz = viz_img(img, y[j, :, :], probs[j, :, :, :])

                            path = vis_dir + '%d.jpg' % (j)

                            cv2.imwrite(path, viz)

                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length
                epoch_loss = epoch_loss / length
                # Print the log info
                print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                   epoch+1, self.num_epochs, \
                   epoch_loss,\
                   acc,SE,SP,PC,F1,JS,DC))

                # Decay learning rate
                if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                    lr -= (self.lr / float(self.num_epochs_decay))
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Decay learning rate to lr: {}.'.format(lr))

                #===================================== Validation ====================================#
                self.unet.train(False)
                self.unet.eval()

                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                length = 0
                for i, (images, GT) in enumerate(self.valid_loader):

                    images = images.to(self.device)
                    GT = GT.to(self.device).squeeze(1)
                    SR = (self.unet(images))
                    SR_probs = F.softmax(SR, dim=1)
                    acc += get_accuracy(SR_probs[:, 1:2, :, :], GT)
                    SE += get_sensitivity(SR_probs[:, 1:2, :, :], GT)
                    SP += get_specificity(SR_probs[:, 1:2, :, :], GT)
                    PC += get_precision(SR_probs[:, 1:2, :, :], GT)
                    F1 += get_F1(SR_probs[:, 1:2, :, :], GT)
                    JS += get_JS(SR_probs[:, 1:2, :, :], GT)
                    DC += get_DC(SR_probs[:, 1:2, :, :], GT)

                    length += images.size(0)

                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length
                unet_score = JS + DC

                print(
                    '[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'
                    % (acc, SE, SP, PC, F1, JS, DC))
                '''
				torchvision.utils.save_image(images.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_image.png'%(self.model_type,epoch+1)))
				torchvision.utils.save_image(SR.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_SR.png'%(self.model_type,epoch+1)))
				torchvision.utils.save_image(GT.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_GT.png'%(self.model_type,epoch+1)))
				'''

                #Save Best U-Net model
                if unet_score > best_unet_score:
                    best_unet_score = unet_score
                    best_epoch = epoch
                    best_unet = self.unet.state_dict()
                    print('Best %s model score : %.4f' %
                          (self.model_type, best_unet_score))
                    torch.save(best_unet, unet_path)

            #===================================== Test ====================================#
            del self.unet
            # del best_unet
            self.build_model()
            self.unet.load_state_dict(torch.load(unet_path))

            self.unet.train(False)
            self.unet.eval()
コード例 #4
0
ファイル: solver.py プロジェクト: 4m4n5/Image_Segmentation
class Solver(object):
    def __init__(self, config, train_loader, valid_loader, test_loader):
        """Keep the loaders, copy the run settings, set up plotting and build the model.

        Args:
            config: object exposing the settings copied below as attributes.
            train_loader: iterable of ``(images, GT)`` training batches.
            valid_loader: iterable of ``(images, GT)`` validation batches.
            test_loader: stored on the instance; not read anywhere in this class.
        """
        self.train_loader, self.valid_loader, self.test_loader = (
            train_loader, valid_loader, test_loader)

        # Filled in by build_model() at the end of construction.
        self.unet = None
        self.optimizer = None

        # Settings copied verbatim from the config object.
        for option in (
            'img_ch', 'output_ch', 'augmentation_prob',
            'lr', 'beta1', 'beta2', 'lamda',
            'num_epochs', 'num_epochs_decay', 'batch_size', 'save_model',
            'log_step', 'val_step',
            'model_path', 'result_path', 'mode',
            'model_type', 't',
        ):
            setattr(self, option, getattr(config, option))

        self.bce_loss = torch.nn.BCELoss()

        # One history per plotted curve family, all drawn on a shared canvas.
        self.loss_history = hl.History()
        self.acc_history = hl.History()
        self.dc_history = hl.History()
        self.canvas = hl.Canvas()

        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        else:
            self.device = torch.device('cpu')

        self.build_model()

    def build_model(self):
        # Load required model
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)

        # Load optimizer
        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        # Move model to device
        self.unet.to(self.device)

    def print_network(self, model, name):
        """Print the module, then its label, then its total parameter count."""
        total_elements = 0
        for weight in model.parameters():
            total_elements = total_elements + weight.numel()
        for line in (model, name,
                     "The number of parameters: {}".format(total_elements)):
            print(line)

    def dice_loss(self, pred, target):
        """Return the smoothed Dice loss between ``pred`` and ``target``.

        Both tensors are flattened and compared element by element, so they
        must hold the same number of elements. The loss is
        ``1 - (2 * sum(pred * target) + 1) / (sum(pred + target) + 1)``,
        computed over all elements at once.

        Args:
            pred: tensor of predicted values.
            target: tensor of reference values.

        Returns:
            Scalar tensor holding the loss.
        """
        # The sums run over every element, so a flat view is sufficient. The
        # previous ``view(32, -1)`` failed whenever the element count was not
        # divisible by 32 (for example a smaller final batch).
        pred = pred.reshape(-1)
        target = target.reshape(-1)
        numerator = 2 * torch.sum(pred * target)
        denominator = torch.sum(pred + target)
        return 1 - (numerator + 1) / (denominator + 1)

    def train(self):
        """Train the network, plotting curves and saving a sample grid every epoch.

        If a checkpoint for the current configuration already exists under
        ``self.model_path`` it is loaded and nothing else happens. Otherwise
        the model is trained for ``self.num_epochs`` epochs with the loss
        ``bce_loss + self.lamda * dice_loss`` on the sigmoid of the network
        output. After each epoch the model is evaluated on
        ``self.valid_loader``, the loss / accuracy / Dice histories are logged
        and redrawn, an image grid built from the last validation batch is
        written to ``self.result_path``, and, when ``self.save_model`` is set,
        the weights with the highest ``JS + DC`` seen so far are saved.
        """

        # Debugging (Uncomment following lines)
        # a = torch.zeros((4, 3, 224, 224))
        # self.unet(a.to(self.device))

        # The checkpoint file name encodes the configuration it was trained with.
        unet_path = os.path.join(self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %(self.model_type,self.num_epochs,\
                                                                             self.lr,self.num_epochs_decay,\
                                                                             self.augmentation_prob))

        # U-Net Train
        if os.path.isfile(unet_path):
            # Load the pretrained Encoder
            self.unet.load_state_dict(torch.load(unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, unet_path))
        else:
            # Train for Encoder
            # NOTE: ``lr`` is only referenced by the commented-out decay block below.
            lr = self.lr
            best_unet_score = 0.

            for epoch in range(self.num_epochs):

                self.unet.train(True)
                epoch_loss = 0

                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                length = 0

                for i, (images, GT) in enumerate(self.train_loader):
                    # GT : Ground Truth
                    images = images.to(self.device)
                    GT = GT.to(self.device)

                    # Zero grad
                    self.optimizer.zero_grad()

                    # SR : Segmentation Result
                    SR = self.unet(images)
                    SR_probs = torch.sigmoid(SR)

                    # Flatten to (batch, -1) for loss calculation
                    SR_flat = SR_probs.view(SR_probs.size(0), -1)
                    GT_flat = GT.view(GT.size(0), -1)

                    # Compute loss: BCE plus the lamda-weighted Dice term
                    loss = self.bce_loss(
                        SR_flat, GT_flat) + self.lamda * self.dice_loss(
                            SR_flat, GT_flat)
                    epoch_loss += loss.item()

                    # Backprop
                    loss.backward()
                    self.optimizer.step()

                    # Get metrics
                    acc += get_accuracy(SR_probs, GT)
                    SE += get_sensitivity(SR_probs, GT)
                    SP += get_specificity(SR_probs, GT)
                    PC += get_precision(SR_probs, GT)
                    F1 += get_F1(SR_probs, GT)
                    JS += get_JS(SR_probs, GT)
                    DC += get_DC(SR_probs, GT)
                    # Overwritten right after the loop; has no effect on the result.
                    length = i

                # Average over the number of batches (i is the last batch index).
                # NOTE(review): ``i`` is unbound here if train_loader yields nothing.
                length = (i + 1)
                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length

                # Keep the training figures; the names are reused for validation below.
                train_dc = DC
                train_acc = acc
                train_loss = epoch_loss / length

                # Learning-rate decay is disabled in this variant:
                #                 # Decay learning rate
                #                 if (epoch+1) > (self.num_epochs - self.num_epochs_decay):
                #                     lr -= (self.lr / float(self.num_epochs_decay))
                #                     for param_group in self.optimizer.param_groups:
                #                         param_group['lr'] = lr
                #                     print ('Decay learning rate to lr: {}.'.format(lr))

                # VALIDATION (no gradients are tracked inside this block)
                with torch.no_grad():
                    epoch_loss = 0
                    self.unet.train(False)
                    self.unet.eval()

                    acc = 0.  # Accuracy
                    SE = 0.  # Sensitivity (Recall)
                    SP = 0.  # Specificity
                    PC = 0.  # Precision
                    F1 = 0.  # F1 Score
                    JS = 0.  # Jaccard Similarity
                    DC = 0.  # Dice Coefficient
                    length = 0
                    for i, (images, GT) in enumerate(self.valid_loader):

                        images = images.to(self.device)
                        GT = GT.to(self.device)
                        SR = torch.sigmoid(self.unet(images))

                        # Flatten to (batch, -1) for loss calculation
                        SR_flat = SR.view(SR.size(0), -1)
                        GT_flat = GT.view(GT.size(0), -1)

                        # Compute loss (same formula as in training)
                        loss = self.bce_loss(
                            SR_flat, GT_flat) + self.lamda * self.dice_loss(
                                SR_flat, GT_flat)
                        epoch_loss += loss.item()

                        acc += get_accuracy(SR, GT)
                        SE += get_sensitivity(SR, GT)
                        SP += get_specificity(SR, GT)
                        PC += get_precision(SR, GT)
                        F1 += get_F1(SR, GT)
                        JS += get_JS(SR, GT)
                        DC += get_DC(SR, GT)

                        # Overwritten right after the loop; has no effect on the result.
                        length = i

                    # Average over the number of validation batches.
                    length = (i + 1)
                    acc = acc / length
                    SE = SE / length
                    SP = SP / length
                    PC = PC / length
                    F1 = F1 / length
                    JS = JS / length
                    DC = DC / length
                    # Model-selection score: Jaccard similarity plus Dice coefficient.
                    unet_score = JS + DC

                    valid_dc = DC
                    valid_acc = acc
                    valid_loss = epoch_loss / length

                    # Record this epoch's train/validation figures (1-based epoch index).
                    self.loss_history.log(epoch + 1,
                                          train_loss=train_loss,
                                          valid_loss=valid_loss)
                    self.acc_history.log(epoch + 1,
                                         train_acc=train_acc,
                                         valid_acc=valid_acc)
                    self.dc_history.log(epoch + 1,
                                        train_dc=train_dc,
                                        valid_dc=valid_dc)

                    # Redraw the three train-vs-validation plots.
                    with self.canvas:
                        self.canvas.draw_plot(
                            [
                                self.loss_history['train_loss'],
                                self.loss_history['valid_loss']
                            ],
                            labels=['Train Loss', 'Valid loss'])
                        self.canvas.draw_plot(
                            [
                                self.acc_history['train_acc'],
                                self.acc_history['valid_acc']
                            ],
                            labels=['Train Acc', 'Valid Acc'])
                        self.canvas.draw_plot(
                            [
                                self.dc_history['train_dc'],
                                self.dc_history['valid_dc']
                            ],
                            labels=['Train Dice Coeff', 'Valid Dice Coeff'])

                    # Image grid from the LAST validation batch: inputs, predictions
                    # and ground truth stacked along the batch axis. SR and GT are
                    # repeated three times along dim 1 to line up with the inputs.
                    # NOTE(review): (images + 1) / 2 presumably maps inputs from
                    # [-1, 1] back to [0, 1] -- confirm against the data loader.
                    grid_images = torch.cat([(images + 1) / 2,
                                             torch.cat([SR, SR, SR], dim=1),
                                             torch.cat([GT, GT, GT], dim=1)])
                    grid = torchvision.utils.make_grid(grid_images, nrow=4)
                    torchvision.utils.save_image(grid, \
                                                  os.path.join(self.result_path,'%s_valid_%d_image.png'%\
                                                               (self.model_type,epoch+1)))
                    # Save Best U-Net model (only when checkpointing is enabled)
                    if self.save_model:
                        if unet_score > best_unet_score:
                            best_unet_score = unet_score
                            best_epoch = epoch
                            best_unet = self.unet.state_dict()
                            print('Best %s model score : %.4f' %
                                  (self.model_type, best_unet_score))
                            torch.save(best_unet, unet_path)

    def test(self):
        """Evaluate the saved checkpoint and append the metrics to ``result.csv``.

        The checkpoint path is rebuilt from the configuration using the same
        file-name pattern as ``train()``. A fresh model is built, the saved
        weights are loaded into it, and the seven metrics are averaged over
        the batches of ``self.valid_loader`` (batch-wise averaging, as in
        ``train()``). One row is then appended to ``result.csv`` under
        ``self.result_path``.

        The ``best_epoch`` column is taken from a ``best_epoch`` attribute on
        the instance when one has been set and is written as ``None``
        otherwise (``train()`` keeps its ``best_epoch`` in a local variable).

        Raises:
            FileNotFoundError: if no checkpoint exists for this configuration.
            ZeroDivisionError: if ``self.valid_loader`` yields no batches.
        """
        # Previously this method referenced unet_path, best_unet and best_epoch,
        # none of which exist in this scope, so it always raised.
        unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %
            (self.model_type, self.num_epochs, self.lr, self.num_epochs_decay,
             self.augmentation_prob))

        # Drop the current network before building the replacement.
        del self.unet
        self.build_model()
        # map_location lets weights saved on a GPU load on a CPU-only host.
        self.unet.load_state_dict(
            torch.load(unet_path, map_location=self.device))

        self.unet.eval()

        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        F1 = 0.  # F1 Score
        JS = 0.  # Jaccard Similarity
        DC = 0.  # Dice Coefficient
        num_batches = 0

        # NOTE(review): this evaluates self.valid_loader, not self.test_loader --
        # confirm this is intended.
        with torch.no_grad():
            for images, GT in self.valid_loader:
                images = images.to(self.device)
                GT = GT.to(self.device)

                SR = torch.sigmoid(self.unet(images))
                acc += get_accuracy(SR, GT)
                SE += get_sensitivity(SR, GT)
                SP += get_specificity(SR, GT)
                PC += get_precision(SR, GT)
                F1 += get_F1(SR, GT)
                JS += get_JS(SR, GT)
                DC += get_DC(SR, GT)

                num_batches += 1

        # Average over batches so the figures are comparable with train().
        acc = acc / num_batches
        SE = SE / num_batches
        SP = SP / num_batches
        PC = PC / num_batches
        F1 = F1 / num_batches
        JS = JS / num_batches
        DC = DC / num_batches

        best_epoch = getattr(self, 'best_epoch', None)

        with open(os.path.join(self.result_path, 'result.csv'),
                  'a',
                  encoding='utf-8',
                  newline='') as f:
            wr = csv.writer(f)
            wr.writerow([
                self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr,
                best_epoch, self.num_epochs, self.num_epochs_decay,
                self.augmentation_prob
            ])
コード例 #5
0
class Solver(object):
    def __init__(self, config, train_loader, valid_loader, test_loader):

        # Data loader
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.criterion = torch.nn.BCELoss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size

        # Step size
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Path
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()

    def build_model(self):
        """Create the selected network and its Adam optimiser.

        The network class is chosen by ``self.model_type``, the optimiser is
        built over its parameters with ``self.lr`` and the betas
        ``(self.beta1, self.beta2)``, and the network is moved to
        ``self.device``.

        Raises:
            ValueError: If ``self.model_type`` is not one of ``'U_Net'``,
                ``'R2U_Net'``, ``'AttU_Net'`` or ``'R2AttU_Net'``.
        """
        # NOTE(review): channel counts are fixed at 3 in / 1 out here, so
        # ``self.img_ch`` and ``self.output_ch`` from the config are ignored;
        # confirm that is intended.
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # ``self.unet.parameters()`` a few lines further down.
            raise ValueError('Unsupported model_type: %r' %
                             (self.model_type, ))

        self.optimizer = optim.Adam(self.unet.parameters(),
                                    lr=self.lr,
                                    betas=(self.beta1, self.beta2))
        self.unet.to(self.device)

    def reset_grad(self):
        """Clear the gradients held by the network's parameters."""
        network = self.unet
        network.zero_grad()

    def train(self):
        """Train the network, checkpoint it, then score the best weights.

        If a weights file for the current hyper-parameter combination already
        exists under ``self.model_path`` it is loaded and nothing else
        happens.  Otherwise the network is trained for ``self.num_epochs``
        epochs.  After every epoch it is evaluated on ``self.valid_loader``,
        the last validation batch is written to ``self.result_path`` as PNG
        images, the current state is checkpointed, and the weights with the
        highest ``JS + DC`` validation score are saved.  Finally those best
        weights are reloaded, evaluated once more and a summary row is
        appended to ``result.csv``.
        """
        unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %
            (self.model_type, self.num_epochs, self.lr, self.num_epochs_decay,
             self.augmentation_prob))
        print(unet_path)

        if os.path.isfile(unet_path):
            # Weights for this configuration already exist: reuse them.
            self.unet.load_state_dict(torch.load(unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, unet_path))
            return

        # The rolling per-epoch checkpoint gets its own file.  It used to be
        # written to ``unet_path`` as a dict, which clobbered the best weights
        # and made the ``load_state_dict`` calls on that file fail whenever
        # the last epoch was not the best one.
        last_path = unet_path[:-4] + '-last.pkl'

        lr = self.lr
        best_unet_score = 0.
        best_epoch = None  # stays None until some epoch beats the score

        for epoch in range(self.num_epochs):
            #================================ Training ================================#
            self.unet.train(True)
            epoch_loss = 0
            totals = [0.] * 7
            length = 0

            for images, GT in self.train_loader:
                # GT : Ground Truth
                images = images.to(self.device)
                GT = GT.to(self.device)

                # SR : Segmentation Result (raw network output)
                SR = self.unet(images)
                SR_probs = torch.sigmoid(SR)
                SR_flat = SR_probs.view(SR_probs.size(0), -1)
                GT_flat = GT.view(GT.size(0), -1)
                loss = self.criterion(SR_flat, GT_flat)
                epoch_loss += loss.item()

                # Backprop + optimize
                self.reset_grad()
                loss.backward()
                self.optimizer.step()

                # Score the probabilities, as validation does; the raw
                # outputs were being thresholded here before.
                batch = self._batch_metrics(SR_probs.detach(), GT)
                totals = [total + m for total, m in zip(totals, batch)]
                length += images.size(0)

            acc, SE, SP, PC, F1, JS, DC = (total / length for total in totals)

            # Print the log info
            print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                epoch + 1, self.num_epochs,
                epoch_loss,
                acc, SE, SP, PC, F1, JS, DC))

            # Decay learning rate linearly over the last num_epochs_decay epochs
            if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                lr -= (self.lr / float(self.num_epochs_decay))
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
                print('Decay learning rate to lr: {}.'.format(lr))

            #=============================== Validation ===============================#
            (acc, SE, SP, PC, F1, JS,
             DC), (images, SR, GT) = self._evaluate(self.valid_loader)
            unet_score = JS + DC

            print(
                '[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'
                % (acc, SE, SP, PC, F1, JS, DC))

            # Dump the last validation batch for visual inspection.
            for tensor, tag in ((images, 'image'), (SR, 'SR'), (GT, 'GT')):
                torchvision.utils.save_image(
                    tensor.data.cpu(),
                    os.path.join(
                        self.result_path, '%s_valid_%d_%s.png' %
                        (self.model_type, epoch + 1, tag)))

            # Checkpoint the current epoch (this line used to print the best
            # score instead of the current one).
            print('Actual %s model score : %.4f' %
                  (self.model_type, unet_score))
            torch.save(
                {
                    'epoch': epoch,
                    'model_state_dict': self.unet.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(),
                    'loss': loss.detach()
                }, last_path)

            # Save the weights only when they beat every earlier epoch.
            if unet_score > best_unet_score:
                best_unet_score = unet_score
                best_epoch = epoch
                print('Best %s model score : %.4f' %
                      (self.model_type, best_unet_score))
                torch.save(self.unet.state_dict(), unet_path)

        #=================================== Test ===================================#
        if best_epoch is None:
            # Nothing was written to unet_path, so there is nothing to reload.
            # This situation used to crash on ``del best_unet``.
            print('No epoch improved on the initial score; skipping test.')
            return

        del self.unet
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))

        # NOTE(review): this final pass reads the validation loader, not
        # ``self.test_loader``; kept as-is, confirm it is intended.
        (acc, SE, SP, PC, F1, JS, DC), _ = self._evaluate(self.valid_loader)

        with open(os.path.join(self.result_path, 'result.csv'),
                  'a',
                  encoding='utf-8',
                  newline='') as f:
            csv.writer(f).writerow([
                self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr,
                best_epoch, self.num_epochs, self.num_epochs_decay,
                self.augmentation_prob
            ])

    def _batch_metrics(self, SR, GT):
        """Return ``(acc, SE, SP, PC, F1, JS, DC)`` for one batch."""
        return (get_accuracy(SR, GT), get_sensitivity(SR, GT),
                get_specificity(SR, GT), get_precision(SR, GT),
                get_F1(SR, GT), get_JS(SR, GT), get_DC(SR, GT))

    def _evaluate(self, loader):
        """Average the batch metrics of the current weights over ``loader``.

        Returns:
            A pair ``(metrics, last_batch)``: ``metrics`` is the list
            ``[acc, SE, SP, PC, F1, JS, DC]`` and ``last_batch`` is the
            ``(images, SR, GT)`` tuple of the final batch.
        """
        self.unet.train(False)
        self.unet.eval()

        totals = [0.] * 7
        length = 0
        last_batch = None
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            for images, GT in loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = torch.sigmoid(self.unet(images))
                batch = self._batch_metrics(SR, GT)
                totals = [total + m for total, m in zip(totals, batch)]
                length += images.size(0)
                last_batch = (images, SR, GT)

        # NOTE(review): batch-level sums are divided by the sample count, as
        # before.  If the metric helpers return per-batch means this
        # understates the result for batch sizes above 1; confirm.
        return [total / length for total in totals], last_batch
コード例 #6
0
class Solver(object):
    def __init__(self, config, train_loader, valid_loader, test_loader):

        # Data loader
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config.img_ch
        self.output_ch = config.output_ch
        self.criterion = torch.nn.BCELoss()
        self.augmentation_prob = config.augmentation_prob

        # Hyper-parameters
        self.lr = config.lr
        self.beta1 = config.beta1
        self.beta2 = config.beta2

        # Training settings
        self.num_epochs = config.num_epochs
        self.num_epochs_decay = config.num_epochs_decay
        self.batch_size = config.batch_size

        # Step size
        self.log_step = config.log_step
        self.val_step = config.val_step

        # Path
        self.model_path = config.model_path
        self.result_path = config.result_path
        self.mode = config.mode

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()

    def build_model(self):
        """Create the selected network and its Adam optimiser.

        The network class is chosen by ``self.model_type``, the optimiser is
        built over its parameters with ``self.lr`` and the betas
        ``(self.beta1, self.beta2)``, and the network is moved to
        ``self.device``.

        Raises:
            ValueError: If ``self.model_type`` is not one of ``'U_Net'``,
                ``'R2U_Net'``, ``'AttU_Net'`` or ``'R2AttU_Net'``.
        """
        # NOTE(review): channel counts are fixed at 3 in / 1 out here, so
        # ``self.img_ch`` and ``self.output_ch`` from the config are ignored;
        # confirm that is intended.
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3, output_ch=1, t=self.t)
        else:
            # Fail here with a clear message instead of an AttributeError on
            # ``self.unet.parameters()`` a few lines further down.
            raise ValueError('Unsupported model_type: %r' %
                             (self.model_type, ))

        self.optimizer = optim.Adam(self.unet.parameters(),
                                    lr=self.lr,
                                    betas=(self.beta1, self.beta2))
        self.unet.to(self.device)

    def print_network(self, model, name):
        """Print ``model``, then ``name``, then the model's parameter count.

        Args:
            model: Object whose ``parameters()`` yields tensors.
            name: Label printed on its own line after the model.
        """
        total = sum(param.numel() for param in model.parameters())
        for line in (model, name,
                     "The number of parameters: {}".format(total)):
            print(line)

    def to_data(self, x):
        """Return ``x.data``, first copied to the CPU when CUDA is available."""
        on_host = x.cpu() if torch.cuda.is_available() else x
        return on_host.data

    def update_lr(self, g_lr, d_lr):
        """Set the learning rate of every optimiser parameter group.

        Args:
            g_lr: Learning rate written to each group of ``self.optimizer``.
            d_lr: Unused; accepted only so existing two-argument calls keep
                working (this solver has a single optimiser).
        """
        # The body used to assign the undefined name ``lr``, so any call
        # raised NameError.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = g_lr

    def reset_grad(self):
        """Clear the gradients held by the network's parameters."""
        model = self.unet
        model.zero_grad()

    def compute_accuracy(self, SR, GT):
        """Return the fraction of elements where the thresholded ``SR`` equals ``GT``.

        ``SR`` is binarised at 0.5 and compared element-wise with ``GT`` after
        both are flattened and moved to the CPU.

        Args:
            SR: Prediction tensor.
            GT: Ground-truth tensor with the same number of elements.

        Returns:
            The matching fraction as a Python float; ``0.0`` for empty input.
        """
        SR_flat = SR.view(-1)
        GT_flat = GT.view(-1)

        matches = GT_flat.data.cpu() == (SR_flat.data.cpu() > 0.5)
        if matches.numel() == 0:
            # The mean of an empty tensor is NaN; report zero instead.
            return 0.0
        # The comparison used to be computed and then dropped, so the method
        # always returned None.
        return matches.float().mean().item()

    def tensor2img(self, x):
        """Turn a two-channel map into a 0/255 float mask.

        An output element is 255 where channel 0 of ``x`` is strictly greater
        than channel 1, and 0 elsewhere.  The channel axis (dimension 1) is
        dropped from the result.
        """
        first_channel = x[:, 0, :, :]
        second_channel = x[:, 1, :, :]
        mask = (first_channel > second_channel).float()
        return mask * 255

    def train(self, pretrain, pre_bestscore):
        """Train the network from scratch or continue from existing weights.

        When the weights file already exists it is loaded and training
        continues from it; improved weights are written next to it with a
        ``_pretrained`` suffix and no final evaluation is run.  Otherwise the
        network is trained from its current state, the best weights are saved
        to the weights file, reloaded, evaluated on ``self.valid_loader`` and
        a summary row is appended to ``result.csv`` in ``self.result_path``.

        Args:
            pretrain: ``0`` to derive the weights file name from the
                hyper-parameters inside ``self.model_path``; any other value
                to use ``self.model_path`` itself as the weights file.
            pre_bestscore: ``JS + DC`` validation score that must be beaten
                before weights are saved when continuing from existing ones.
        """
        if pretrain == 0:
            unet_path = os.path.join(
                self.model_path, '%s-%d-%.4f-%d-%.4f.pkl' %
                (self.model_type, self.num_epochs, self.lr,
                 self.num_epochs_decay, self.augmentation_prob))
        else:
            unet_path = self.model_path

        if os.path.isfile(unet_path):
            # Continue from the stored weights.
            self.unet.load_state_dict(torch.load(unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, unet_path))
            premodel_unet_path = unet_path[:-4] + '_pretrained' + '.pkl'
            self._fit(pre_bestscore, premodel_unet_path, announce_path=True)
            return

        best_epoch = self._fit(0., unet_path, announce_path=False)

        #=================================== Test ===================================#
        if best_epoch is None:
            # Nothing was written to unet_path, so there is nothing to reload.
            # This situation used to crash on ``del best_unet``.
            print('No epoch improved on the initial score; skipping test.')
            return

        del self.unet
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))

        # NOTE(review): this final pass reads the validation loader, not
        # ``self.test_loader``; kept as-is, confirm it is intended.
        acc, SE, SP, PC, F1, JS, DC = self._evaluate(self.valid_loader)

        with open(os.path.join(self.result_path, 'result.csv'),
                  'a',
                  encoding='utf-8',
                  newline='') as f:
            csv.writer(f).writerow([
                self.model_type, acc, SE, SP, PC, F1, JS, DC, self.lr,
                best_epoch, self.num_epochs, self.num_epochs_decay,
                self.augmentation_prob
            ])

    def _fit(self, best_unet_score, save_path, announce_path):
        """Run the epoch loop and save the weights whenever the score improves.

        Args:
            best_unet_score: Validation score (``JS + DC``) to beat.
            save_path: File that receives the state dict of each new best.
            announce_path: When true, the "Best" log line also shows
                ``save_path``.

        Returns:
            The zero-based index of the best epoch, or ``None`` when no epoch
            beat ``best_unet_score``.
        """
        lr = self.lr
        best_epoch = None

        for epoch in range(self.num_epochs):
            #================================ Training ================================#
            self.unet.train(True)
            epoch_loss = 0
            totals = [0.] * 7
            length = 0

            for images, GT, _, _ in self.train_loader:
                # GT : Ground Truth
                images = images.to(self.device)
                GT = GT.to(self.device)

                # SR : Segmentation Result
                SR = self.unet(images)
                # torch.sigmoid replaces the deprecated F.sigmoid.
                SR_probs = torch.sigmoid(SR)
                SR_flat = SR_probs.view(SR_probs.size(0), -1)
                GT_flat = GT.view(GT.size(0), -1)

                loss = self.criterion(SR_flat, GT_flat)
                epoch_loss += loss.item()

                # Backprop + optimize
                self.reset_grad()
                loss.backward()
                self.optimizer.step()

                # NOTE(review): metrics are computed on the raw network
                # output here and in validation, while the loss uses the
                # sigmoid; confirm the metric helpers expect that.
                batch = self._batch_metrics(SR, GT)
                totals = [total + m for total, m in zip(totals, batch)]
                length += images.size(0)

            acc, SE, SP, PC, F1, JS, DC = (total / length for total in totals)

            # Print the log info
            print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                epoch + 1, self.num_epochs,
                epoch_loss,
                acc, SE, SP, PC, F1, JS, DC))

            # Decay learning rate linearly over the last num_epochs_decay epochs
            if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                lr -= (self.lr / float(self.num_epochs_decay))
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] = lr
                print('Decay learning rate to lr: {}.'.format(lr))

            #=============================== Validation ===============================#
            acc, SE, SP, PC, F1, JS, DC = self._evaluate(self.valid_loader)
            unet_score = JS + DC

            print(
                '[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'
                % (acc, SE, SP, PC, F1, JS, DC))

            # Save the weights only when they beat every earlier score.
            if unet_score > best_unet_score:
                best_unet_score = unet_score
                best_epoch = epoch
                if announce_path:
                    print(
                        'Best %s model score : %.4f unet_path is ' %
                        (self.model_type, best_unet_score), save_path)
                else:
                    print('Best %s model score : %.4f' %
                          (self.model_type, best_unet_score))
                torch.save(self.unet.state_dict(), save_path)

        return best_epoch

    def _batch_metrics(self, SR, GT):
        """Return ``(acc, SE, SP, PC, F1, JS, DC)`` for one batch."""
        return (get_accuracy(SR, GT), get_sensitivity(SR, GT),
                get_specificity(SR, GT), get_precision(SR, GT),
                get_F1(SR, GT), get_JS(SR, GT), get_DC(SR, GT))

    def _evaluate(self, loader):
        """Return ``[acc, SE, SP, PC, F1, JS, DC]`` averaged over ``loader``."""
        self.unet.train(False)
        self.unet.eval()

        totals = [0.] * 7
        length = 0
        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            for images, GT, _, _ in loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = self.unet(images)
                batch = self._batch_metrics(SR, GT)
                totals = [total + m for total, m in zip(totals, batch)]
                length += images.size(0)

        # NOTE(review): batch-level sums are divided by the sample count, as
        # before.  If the metric helpers return per-batch means this
        # understates the result for batch sizes above 1; confirm.
        return [total / length for total in totals]

    def test(self,
             unet_path,
             result_savepath,
             mask_savepath,
             pre_savepath,
             threshold=0.5):
        """Evaluate stored weights on ``self.test_loader`` and save positive cases.

        The network is rebuilt and loaded from ``unet_path``.  For every test
        batch the metrics are accumulated, and when any output element
        exceeds ``threshold`` the prediction, the ground-truth mask and the
        input image are saved as PNG files.  A summary is printed at the end.

        Args:
            unet_path: File holding the state dict to evaluate.
            result_savepath: Prefix for the saved input images; passed to
                ``rm_mkdir`` first.
            mask_savepath: Prefix for the saved ground-truth masks; passed to
                ``rm_mkdir`` first.
            pre_savepath: Prefix for the saved predictions; passed to
                ``rm_mkdir`` first.
            threshold: Cut-off forwarded to the metric helpers and used to
                decide whether a batch contains a positive prediction.
        """
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))

        self.unet.train(False)
        self.unet.eval()

        # Order: acc, SE, SP, PC, F1, JS, DC.
        metric_fns = (get_accuracy, get_sensitivity, get_specificity,
                      get_precision, get_F1, get_JS, get_DC)
        totals = [0.] * len(metric_fns)
        length = 0
        num_recall = 0
        rm_mkdir(result_savepath)
        rm_mkdir(mask_savepath)
        rm_mkdir(pre_savepath)

        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            for images, GT, HW, filename in self.test_loader:
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = self.unet(images)
                totals = [
                    total + fn(SR, GT, threshold)
                    for total, fn in zip(totals, metric_fns)
                ]

                # A batch counts as positive when any element passes the
                # threshold.
                if (SR > threshold).any():
                    num_recall += 1
                    SR_PIL_img = saveImg(SR, HW)
                    GT_PIL_img = saveImg_GT(GT, HW)
                    images_PIL_img = saveImg_contour(images, HW)
                    # Paths are built by plain concatenation, so the
                    # *_savepath arguments need their trailing separator.
                    # NOTE(review): only filename[0] is used, which presumes
                    # one sample per batch; confirm.
                    SR_PIL_img.save(pre_savepath + filename[0] + ".png")
                    GT_PIL_img.save(mask_savepath + filename[0] + "_mask.png")
                    images_PIL_img.save(result_savepath + filename[0] +
                                        ".png")
                length += images.size(0)

        acc, SE, SP, PC, F1, JS, DC = (total / length for total in totals)
        # Share of samples for which a positive prediction was produced.
        recall = num_recall / length
        print('acc:{} DC:{} F1:{}'.format(acc, DC, F1))
        print('length:{} num_recall:{} recall:{}'.format(
            length, num_recall, recall))

    def predict(self,
                unet_path,
                raw_savepath,
                pre_savepath,
                zeros_savepath,
                threshold=0.5):
        """Run stored weights over ``self.test_loader`` and save the outputs.

        The network is rebuilt and loaded from ``unet_path``.  When any output
        element of a batch exceeds ``threshold`` the input image and the
        prediction are saved as PNG files; otherwise an all-zero 512x512x3
        image is written instead.  A summary is printed at the end.

        Args:
            unet_path: File holding the state dict to use.
            raw_savepath: Prefix for the saved input images; passed to
                ``rm_mkdir`` first.
            pre_savepath: Prefix for the saved predictions; passed to
                ``rm_mkdir`` first.
            zeros_savepath: Prefix for the all-zero placeholder images;
                passed to ``rm_mkdir`` first.
            threshold: Cut-off deciding whether a batch contains a positive
                prediction.
        """
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))

        self.unet.train(False)
        self.unet.eval()

        num_recall = 0
        length = 0
        rm_mkdir(raw_savepath)
        rm_mkdir(pre_savepath)
        rm_mkdir(zeros_savepath)
        zeros_image = np.zeros(shape=(512, 512, 3))

        # Inference only: do not record an autograd graph.
        with torch.no_grad():
            for images, GT, HW, filename in self.test_loader:
                images = images.to(self.device)
                SR = self.unet(images)

                # Paths are built by plain concatenation, so the *_savepath
                # arguments need their trailing separator.
                if (SR > threshold).any():
                    num_recall += 1
                    images_PIL_img = saveImg_contour(images, HW)
                    SR_PIL_img = saveImg(SR, HW)
                    images_PIL_img.save(raw_savepath + filename[0] + ".png")
                    SR_PIL_img.save(pre_savepath + filename[0] + ".png")
                else:
                    # NOTE(review): scipy.misc.imsave was removed in SciPy
                    # 1.2; this call only works on older SciPy and needs a
                    # replacement writer.
                    scipy.misc.imsave(zeros_savepath + filename[0] + ".png",
                                      zeros_image)
                length += images.size(0)

        # Share of samples for which a positive prediction was produced.
        recall = num_recall / length
        print('length:{} num_recall:{} recall:{}'.format(
            length, num_recall, recall))
コード例 #7
0
ファイル: solver.py プロジェクト: warmtub/Image_Segmentation
class Solver(object):
    def __init__(self, config, train_valid_loader):
        """Store the run configuration on the solver and build the model.

        Args:
            config: Object exposing the attributes copied below.
            train_valid_loader: Data loader kept for the train/validation
                split performed later.
        """
        self.train_valid_loader = train_valid_loader

        # The network and optimizer are filled in by build_model() below.
        self.unet = None
        self.optimizer = None
        self.criterion = torch.nn.BCELoss()

        # Every remaining setting is copied verbatim from the config object.
        copied_fields = (
            # model
            'img_ch', 'output_ch', 'augmentation_prob',
            # hyper-parameters
            'lr', 'beta1', 'beta2',
            # training settings
            'n_splits', 'num_epochs', 'num_epochs_decay', 'batch_size',
            # step sizes
            'log_step', 'val_step',
            # paths and mode
            'model_path', 'result_path', 'mode',
            # architecture selection
            'model_type', 't',
        )
        for field in copied_fields:
            setattr(self, field, getattr(config, field))

        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)

        self.build_model()

    def build_model(self):
        """Create the segmentation network and its Adam optimizer.

        The network class is chosen by ``self.model_type``; the result is
        stored in ``self.unet`` and moved to ``self.device``. The optimizer
        is stored in ``self.optimizer``.

        Raises:
            ValueError: If ``self.model_type`` is not a supported name.
        """
        # NOTE(review): the input channel count is fixed at 3 here and
        # self.img_ch is not used -- confirm this is intentional.
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=3, output_ch=self.output_ch)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=3, output_ch=self.output_ch, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=3, output_ch=self.output_ch)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=3,
                                   output_ch=self.output_ch,
                                   t=self.t)
        else:
            # Fail here with a clear message instead of crashing below on
            # ``None.parameters()``.
            raise ValueError(
                'Unsupported model_type: {!r}'.format(self.model_type))

        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])
        self.unet.to(self.device)

        # self.print_network(self.unet, self.model_type)

    def print_network(self, model, name):
        """Print the model, its label, and its total parameter count."""
        total = sum(weight.numel() for weight in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(total))

    def to_data(self, x):
        """Return ``x.data``, first moving ``x`` to CPU if CUDA is available."""
        on_host = x.cpu() if torch.cuda.is_available() else x
        return on_host.data

    def update_lr(self, g_lr, d_lr=None):
        """Set the learning rate of every optimizer parameter group.

        Args:
            g_lr: New learning rate applied to all groups of
                ``self.optimizer``.
            d_lr: Unused. Kept, now optional, so existing two-argument
                calls keep working; this solver has a single optimizer.
        """
        # The original body assigned the undefined name ``lr`` and raised
        # NameError on every call.
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = g_lr

    def reset_grad(self):
        """Clear the gradients accumulated on the network's parameters."""
        network = self.unet
        network.zero_grad()

    def reset_model(self):
        """Re-initialise all network weights and create a fresh optimizer.

        Every sub-module of ``self.unet`` that has a ``reset_parameters``
        method is reset, at any nesting depth. ``self.optimizer`` is then
        replaced by a new Adam instance so no optimizer state is carried over.
        """
        # ``children()`` only yields the top-level modules, so layers nested
        # inside container blocks were never reset; ``modules()`` walks the
        # whole tree.
        for module in self.unet.modules():
            reset = getattr(module, 'reset_parameters', None)
            if callable(reset):
                reset()
        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    [self.beta1, self.beta2])

    def compute_accuracy(self, SR, GT):
        """Return the fraction of elements where the prediction matches GT.

        ``SR`` is binarised at 0.5 and compared element-wise with ``GT``
        after both are flattened.

        Args:
            SR: Prediction tensor.
            GT: Ground-truth tensor with the same number of elements;
                presumably holds 0/1 values -- TODO confirm.

        Returns:
            float: Share of matching elements, or 0.0 for empty input.
        """
        predicted = SR.detach().reshape(-1).cpu() > 0.5
        truth = GT.detach().reshape(-1).cpu()
        if predicted.numel() == 0:
            return 0.0
        # The original computed this comparison but never returned it.
        matches = truth == predicted.to(truth.dtype)
        return matches.float().mean().item()

    def tensor2img(self, x):
        """Turn a two-channel tensor into a 0/255 mask.

        A position is 255 where channel 0 is strictly greater than channel 1
        and 0 elsewhere. Indexing requires ``x`` to be 4-dimensional.
        """
        first_channel = x[:, 0, :, :]
        second_channel = x[:, 1, :, :]
        mask = torch.gt(first_channel, second_channel).float()
        return mask * 255

    def train(self):
        """Train the network with K-fold cross-validation.

        The dataset of ``self.train_valid_loader`` is split with
        ``KFold(n_splits=self.n_splits, shuffle=True)``. For each fold the
        model and optimizer are reset and trained for
        ``int(self.num_epochs / self.n_splits)`` epochs. After every epoch the
        held-out split is evaluated; when ``JS + DC`` exceeds the fold's best
        score so far, the weights are saved under ``self.model_path``.
        """

        #====================================== Training ===========================================#
        #===========================================================================================#

        # U-Net Train
        # Loading a pretrained checkpoint is disabled; training always runs.
        if False:
            pass
        else:
            # Metric helpers in the order acc, SE, SP, PC, F1, JS, DC.
            metric_fns = (get_accuracy, get_sensitivity, get_specificity,
                          get_precision, get_F1, get_JS, get_DC)
            dataset = self.train_valid_loader.dataset

            kfold = KFold(n_splits=self.n_splits, shuffle=True)
            for fold, (train_index, valid_index) in enumerate(
                    kfold.split(dataset)):

                print(f"Fold{fold} start")
                logging.info(f"Fold{fold} start")
                logging.info(f"train: {train_index} valid: {valid_index}")
                self.reset_model()
                lr = self.lr
                best_unet_score = 0.

                # The loaders depend only on the fold, so build them once;
                # the random samplers still reshuffle on every pass.
                train_loader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=self.batch_size,
                    sampler=SubsetRandomSampler(train_index))
                valid_loader = torch.utils.data.DataLoader(
                    dataset,
                    batch_size=self.batch_size,
                    sampler=SubsetRandomSampler(valid_index))

                for epoch in range(int(self.num_epochs / self.n_splits)):
                    # Epoch number counted across folds (a float; the
                    # '%d' format specifiers below truncate it).
                    global_epoch = (fold * self.num_epochs / self.n_splits +
                                    epoch + 1)

                    # Validation below switches to eval mode, so training
                    # mode must be restored at the start of every epoch,
                    # and the running loss must restart from zero.
                    self.unet.train(True)
                    epoch_loss = 0
                    totals = [0.] * len(metric_fns)
                    length = 0

                    # GT : Ground Truth
                    for images, GT in train_loader:

                        images = images.to(self.device)
                        GT = GT.to(self.device)

                        # SR : Segmentation Result
                        SR = self.unet(images)
                        SR_probs = torch.sigmoid(SR)
                        SR_flat = SR_probs.view(SR_probs.size(0), -1)

                        GT_flat = GT.view(GT.size(0), -1)
                        loss = self.criterion(SR_flat, GT_flat)
                        epoch_loss += loss.item()

                        # Backprop + optimize
                        self.reset_grad()
                        loss.backward()
                        self.optimizer.step()

                        # NOTE(review): training metrics receive the raw
                        # network output while validation passes the
                        # sigmoid output -- confirm which is intended.
                        for k, metric in enumerate(metric_fns):
                            totals[k] += metric(SR, GT)
                        length += images.size(0)

                    acc, SE, SP, PC, F1, JS, DC = (total / length
                                                   for total in totals)
                    print("training ", DC, length)

                    # Print the log info
                    train_msg = 'Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f' % (
                        global_epoch, self.num_epochs, epoch_loss,
                        acc, SE, SP, PC, F1, JS, DC)
                    print(train_msg)
                    logging.info(train_msg)

                    # Decay learning rate
                    if (epoch * self.n_splits + fold +
                            1) > (self.num_epochs - self.num_epochs_decay):
                        lr -= (self.lr / float(self.num_epochs_decay))
                        for param_group in self.optimizer.param_groups:
                            param_group['lr'] = lr
                        print('Decay learning rate to lr: {}.'.format(lr))

                    #===================================== Validation ====================================#
                    self.unet.eval()

                    totals = [0.] * len(metric_fns)
                    length = 0
                    # No gradients are needed for evaluation.
                    with torch.no_grad():
                        for images, GT in valid_loader:

                            images = images.to(self.device)
                            GT = GT.to(self.device)
                            SR = torch.sigmoid(self.unet(images))
                            for k, metric in enumerate(metric_fns):
                                totals[k] += metric(SR, GT)

                            length += images.size(0)

                    acc, SE, SP, PC, F1, JS, DC = (total / length
                                                   for total in totals)
                    unet_score = JS + DC
                    print("valid ", DC, length)

                    valid_msg = (
                        '[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f'
                        % (acc, SE, SP, PC, F1, JS, DC))
                    print(valid_msg)
                    logging.info(valid_msg)

                    # Save Best U-Net model
                    print(f'model score: {unet_score} ({best_unet_score})')
                    logging.info(
                        (f'model score: {unet_score} ({best_unet_score})'))
                    if unet_score > best_unet_score:
                        best_unet_score = unet_score
                        best_epoch = global_epoch
                        best_unet = self.unet.state_dict()
                        print('Best %s model score : %.4f' %
                              (self.model_type, best_unet_score))
                        logging.info('Best %s model score : %.4f' %
                                     (self.model_type, best_unet_score))
                        unet_path = os.path.join(
                            self.model_path, '%s-f%d-%d-%.3f.pkl' %
                            (self.model_type, fold, best_epoch, DC))
                        torch.save(best_unet, unet_path)

            #===================================== Test ====================================#
            """
コード例 #8
0
class Solver(object):
    def __init__(self, config, train_loader, valid_loader, test_loader,
                 whole_slice_prediction_loader):
        """Store the loaders and run configuration, then build the model.

        Args:
            config: Object exposing the attributes copied below.
            train_loader: Loader stored as ``self.train_loader``.
            valid_loader: Loader stored as ``self.valid_loader``.
            test_loader: Loader stored as ``self.test_loader``.
            whole_slice_prediction_loader: Loader stored under the same name.
        """
        # Data loaders
        (self.train_loader, self.valid_loader, self.test_loader,
         self.whole_slice_prediction_loader) = (
             train_loader, valid_loader, test_loader,
             whole_slice_prediction_loader)

        # The network and optimizer are filled in by build_model() below.
        self.unet = None
        self.optimizer = None
        for field in ('img_ch', 'output_ch', 'augmentation_prob',
                      'inverse_ratio'):
            setattr(self, field, getattr(config, field))

        # The learning rate is tracked twice: the starting value and the
        # value currently in effect.
        self.initial_lr = config.lr
        self.current_lr = config.lr

        # Only the hyper-parameters of the chosen optimizer are read.
        choice = config.optimizer_choice
        self.optimizer_choice = choice
        if choice == 'Adam':
            self.beta1, self.beta2 = config.beta1, config.beta2
        elif choice == 'SGD':
            self.momentum = config.momentum
        else:
            print('No such optimizer available')

        # Training settings, step sizes, paths and mode.
        for field in ('num_epochs', 'batch_size', 'PPorLS', 'log_step',
                      'val_step'):
            setattr(self, field, getattr(config, field))
        self.batch_val_num = config.val_freq_batch
        for field in ('model_path', 'result_path', 'result_img_path',
                      'mode'):
            setattr(self, field, getattr(config, field))

        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        self.model_type = config.model_type
        self.t = config.t
        self.build_model()

    def build_model(self):
        """Create the segmentation network and the configured optimizer.

        The network class is chosen by ``self.model_type`` and stored in
        ``self.unet``, then moved to ``self.device``. ``self.optimizer`` is
        set for the 'Adam' and 'SGD' choices and left unchanged for any
        other ``self.optimizer_choice``.

        Raises:
            ValueError: If ``self.model_type`` is not a supported name.
        """
        # NOTE(review): channel counts are fixed at 1 here; self.img_ch and
        # self.output_ch are not used -- confirm this is intentional.
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=1, output_ch=1, t=self.t)
        elif self.model_type == 'AttU_Net':
            self.unet = AttU_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2AttU_Net':
            self.unet = R2AttU_Net(img_ch=1, output_ch=1, t=self.t)
        else:
            # Fail here with a clear message instead of crashing below on
            # ``None.to(...)``.
            raise ValueError(
                'Unsupported model_type: {!r}'.format(self.model_type))

        if self.optimizer_choice == 'Adam':
            self.optimizer = optim.Adam(list(self.unet.parameters()),
                                        self.initial_lr,
                                        [self.beta1, self.beta2])
        elif self.optimizer_choice == 'SGD':
            self.optimizer = optim.SGD(list(self.unet.parameters()),
                                       self.initial_lr, self.momentum)
        # Any other choice is left without an optimizer, as before.

        self.unet.to(self.device)

        #self.print_network(self.unet, self.model_type)

    def print_network(self, model, name):
        """Write the model, its label, and its parameter count to stdout."""
        param_count = sum(tensor.numel() for tensor in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(param_count))

    def dice_coeff_loss(self, y_pred, y_true):
        """Return the negated, smoothed Dice coefficient of the whole batch.

        Both tensors are flattened per sample and the sums run over every
        element of the batch, so one scalar comes back. A smoothing constant
        of 1 is added to the numerator and the denominator.
        """
        eps = 1
        truth = y_true.view(y_true.size(0), -1)
        guess = y_pred.view(y_pred.size(0), -1)

        overlap = (truth * guess).sum()
        numerator = 2. * overlap + eps
        denominator = truth.sum() + guess.sum() + eps
        return -numerator / denominator

    def RR_dice_coeff_loss(self, y_pred, y_true):
        """Return the negated sum of two smoothed Dice terms.

        The first term is computed on the inputs as given, the second on
        their complements (``1 - y_true`` and ``1 - y_pred``). Sums run over
        every element of the batch, with a smoothing constant of 1e-6.
        """
        eps = 1e-6
        truth = y_true.view(y_true.size(0), -1)
        guess = y_pred.view(y_pred.size(0), -1)
        inv_truth = 1 - truth
        inv_guess = 1 - guess

        overlap = (truth * guess).sum()
        inv_overlap = (inv_truth * inv_guess).sum()

        direct_term = (2. * overlap + eps) / (truth.sum() + guess.sum() +
                                              eps)
        inverse_term = (2. * inv_overlap + eps) / (
            inv_truth.sum() + inv_guess.sum() + eps)
        return -direct_term - inverse_term

    def to_data(self, x):
        """Return ``x.data``, moving ``x`` to CPU first if CUDA is available."""
        source = x.cpu() if torch.cuda.is_available() else x
        return source.data

    # Redefine the 'update_lr' function (R&R)
    def update_lr(self, new_lr):
        """Apply ``new_lr`` to every parameter group of the optimizer."""
        for group in self.optimizer.param_groups:
            group.update(lr=new_lr)

    def run_batch_validation(self, epoch, batch_train):
        """Evaluate the network on ``self.valid_loader`` during training.

        The network is put in eval mode, every validation batch is scored,
        the averaged metrics are printed and appended as one row to
        ``model_validation_batch_history.csv`` in ``self.result_path``, and
        the network is put back in training mode.

        Args:
            epoch: Zero-based index of the current training epoch (logged).
            batch_train: Zero-based index of the current training batch
                (logged).

        Returns:
            tuple: ``(validation_batch_loss, unet_score)``. The first is the
            dice loss summed over all validation batches; the second is the
            accumulated ``DC_RR`` divided by the number of samples.
        """
        self.unet.eval()

        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        F1 = 0.  # F1 Score
        JS = 0.  # Jaccard Similarity
        DC = 0.  # Dice Coefficient
        DC_RR = 0
        length = 0

        validation_batch_loss = 0
        # No gradients are needed for evaluation.
        with torch.no_grad():
            for images, GT in self.valid_loader:

                images = images.to(self.device)
                GT = GT.to(self.device)
                # Reshape to 4 dimensions for the conv2d layers. The batch
                # dimension is inferred (-1) so a short final batch works.
                images = images.reshape(-1, self.img_ch, images.shape[1],
                                        images.shape[2])
                GT = GT.reshape(-1, self.img_ch, GT.shape[1], GT.shape[2])

                # One forward pass serves both the metrics and the loss.
                SR = torch.sigmoid(self.unet(images))
                acc += get_accuracy(SR, GT)
                SE += get_sensitivity(SR, GT)
                SP += get_specificity(SR, GT)
                PC += get_precision(SR, GT)
                F1 += get_F1(SR, GT)
                JS += get_JS(SR, GT)
                DC_RR += get_DC_RR(SR, GT, inverse_ratio=self.inverse_ratio)
                DC += get_DC(SR, GT)

                length += images.size(0)

                # Validation loss: dice coefficient loss, not BCE. (R&R)
                SR_flat = SR.view(SR.size(0), -1)
                GT_flat = GT.view(GT.size(0), -1)
                validation_loss = self.dice_coeff_loss(SR_flat, GT_flat)

                validation_batch_loss += validation_loss.item()

        acc = acc / length
        SE = SE / length
        SP = SP / length
        PC = PC / length
        F1 = F1 / length
        JS = JS / length
        DC = DC / length
        DC_RR = DC_RR / length
        unet_score = DC_RR

        print('current batch: {}'.format(batch_train))
        print('Current learning rate: {}'.format(self.current_lr))

        print(
            'Current Batch [%d] \n[Validation] Validation Loss: %.4f, Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, DC_RR: %.4f'
            % (batch_train + 1, validation_batch_loss, acc, SE, SP, PC, F1, JS,
               DC, DC_RR))

        # Append the validation row to the history file (R&R). The ``with``
        # block closes the handle, which the original never did.
        history_path = os.path.join(self.result_path,
                                    'model_validation_batch_history.csv')
        with open(history_path, 'a', encoding='utf-8', newline='') as f:
            wr = csv.writer(f)
            wr.writerow([
                'Validation',
                'Epoch [%d/%d]' % (epoch + 1, self.num_epochs),
                'Batch [%d]' % (batch_train + 1),
                'Validation loss: %.4f' % validation_batch_loss,
                'Accuracy: %.4f' % acc,
                'Sensitivity: %.4f' % SE,
                'Specificity: %.4f' % SP,
                'Precision: %.4f' % PC,
                'F1 Score: %.4f' % F1,
                'Jaccard Similarity: %.4f' % JS,
                'Dice Coefficient: %.4f' % DC,
                'RR_DC: %.4f' % DC_RR
            ])

        self.unet.train(True)

        return (validation_batch_loss, unet_score)

    # Define adaptive learning rate handler (R&R)
    def adaptive_lr_handler(self, cooldown, min_lr, current_epoch,
                            previous_update_epoch, plateau_ratio,
                            adjustment_ratio, loss_history):
        """Scale the learning rate down when the loss rises or plateaus.

        Args:
            cooldown: Minimum number of steps that must have passed since
                ``previous_update_epoch`` before another adjustment.
            min_lr: No adjustment is made once ``self.current_lr`` is at or
                below this value.
            current_epoch: Index of the current step.
            previous_update_epoch: Step of the last adjustment.
            plateau_ratio: Relative change between the last two losses below
                which the loss counts as plateaued.
            adjustment_ratio: Factor ``self.current_lr`` is multiplied by.
            loss_history: Sequence of recorded losses, most recent last.

        Returns:
            ``current_epoch`` when the learning rate was adjusted, otherwise
            ``previous_update_epoch`` unchanged. The original returned None
            in the second case, so callers lost the last update step and the
            cooldown stopped being enforced.
        """
        if current_epoch <= 1 or len(loss_history) < 2:
            return previous_update_epoch
        if current_epoch - previous_update_epoch <= cooldown:
            return previous_update_epoch
        if self.current_lr <= min_lr:
            return previous_update_epoch

        previous, latest = loss_history[-2], loss_history[-1]
        if previous != 0:
            relative_change = abs((previous - latest) / previous)
        else:
            # Avoid dividing by zero: an unchanged loss is a plateau, any
            # other change is treated as unbounded.
            relative_change = 0.0 if latest == previous else float('inf')

        if latest > previous or relative_change < plateau_ratio:
            self.current_lr = adjustment_ratio * self.current_lr
            self.update_lr(self.current_lr)
            print(
                'Validation loss stop decreasing. Adjust the learning rate to {}.'
                .format(self.current_lr))
            return current_epoch
        return previous_update_epoch

    def reset_grad(self):
        """Clear the gradients stored on the network's parameters."""
        model = self.unet
        model.zero_grad()

    def tensor2img(self, x):
        """Build a 0/255 mask from the first two channels of ``x``.

        The result is 255 where channel 0 is strictly greater than channel 1
        and 0 elsewhere. Indexing requires ``x`` to be 4-dimensional.
        """
        channel_a = x[:, 0, :, :]
        channel_b = x[:, 1, :, :]
        return torch.gt(channel_a, channel_b).float() * 255

    def train(self):
        """Train encoder, generator and discriminator."""

        #====================================== Training ===========================================#
        #===========================================================================================#

        unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%.4f-%s-%s-%.4f.pkl' %
            (self.model_type, self.num_epochs, self.initial_lr,
             self.augmentation_prob, self.PPorLS, self.optimizer_choice,
             self.inverse_ratio))
        print('The U-Net path is {}'.format(unet_path))
        # U-Net Train
        # Train loss history (R&R)
        train_loss_history = []
        train_batch_loss_history = []
        # Validation loss history (R&R)
        validation_loss_history = []
        val_batch_loss_history = []
        stop_training = False

        if os.path.isfile(unet_path):
            # Load the pretrained Encoder
            self.unet.load_state_dict(torch.load(unet_path))
            print('%s is Successfully Loaded from %s' %
                  (self.model_type, unet_path))
        else:
            # Train for Encoder
            best_unet_score = 0.
            print('Start training. The initial learning rate is: {}'.format(
                self.initial_lr))

            for epoch in range(self.num_epochs):
                self.unet.train(True)
                train_epoch_loss = 0
                validation_epoch_loss = 0

                if stop_training == True:
                    break
                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                DC_RR = 0
                length = 0

                for batch, (images, GT) in enumerate(self.train_loader):
                    # GT : Ground Truth
                    images = images.to(self.device)
                    GT = GT.to(self.device)
                    # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer. (R&R)
                    images = images.reshape(self.batch_size, self.img_ch,
                                            np.shape(images)[1],
                                            np.shape(images)[2])
                    GT = GT.reshape(self.batch_size, self.img_ch,
                                    np.shape(GT)[1],
                                    np.shape(GT)[2])

                    # SR : Segmentation Result
                    SR = self.unet(images)
                    SR_probs = torch.sigmoid(SR)
                    SR_flat = SR_probs.view(SR_probs.size(0), -1)

                    GT_flat = GT.view(GT.size(0), -1)
                    # Use dice coefficient loss instead of the BCE loss. (R&R)
                    train_loss = self.dice_coeff_loss(SR_flat, GT_flat)

                    train_epoch_loss += train_loss.item()

                    # Backprop + optimize
                    self.reset_grad()
                    train_loss.backward()
                    self.optimizer.step()

                    acc += get_accuracy(SR, GT)
                    SE += get_sensitivity(SR, GT)
                    SP += get_specificity(SR, GT)
                    PC += get_precision(SR, GT)
                    F1 += get_F1(SR, GT)
                    JS += get_JS(SR, GT)
                    DC_RR += get_DC_RR(SR,
                                       GT,
                                       inverse_ratio=self.inverse_ratio)
                    DC += get_DC(SR, GT)
                    length += images.size(0)

                    if epoch == 0:
                        val_frequency = self.batch_val_num[0]
                    else:
                        val_frequency = self.batch_val_num[1]

                    if batch % val_frequency == 0:
                        # update learning rate and record the validation loss history
                        validation_batch_loss, unet_score = self.run_batch_validation(
                            epoch, batch)
                        val_batch_loss_history.append(validation_batch_loss)
                        train_batch_loss_history.append(train_epoch_loss)

                        if unet_score > best_unet_score:
                            best_unet_score = unet_score
                            best_epoch = epoch
                            best_unet = self.unet.state_dict()
                            print('Best %s model score : %.4f' %
                                  (self.model_type, best_unet_score))
                            torch.save(best_unet, unet_path)

# update learning rate
                        batch_id = len(val_batch_loss_history)
                        try:
                            previous_batch_id = self.adaptive_lr_handler(
                                3, 0.01 * self.initial_lr, batch_id,
                                previous_batch_id, 0.001, 0.5,
                                val_batch_loss_history)
                        except:
                            previous_batch_id = self.adaptive_lr_handler(
                                3, 0.01 * self.initial_lr, batch_id, 0, 0.001,
                                0.5, val_batch_loss_history)

                        if ((batch_id - 4) % 10 == 0) and (
                                batch_id >
                                8) or unet_score < 0.2 * best_unet_score:
                            if (np.median(val_batch_loss_history[-10:-5]) >=
                                    np.median(val_batch_loss_history[-5:])):
                                print(
                                    'Validation loss stop decreasing. Stop training.'
                                )
                                stop_training = True
                                break

                if stop_training == True:
                    break

                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length
                DC_RR = DC_RR / length

                # Print the log info
                print('Epoch [%d/%d] \n[Training] Train Loss: %.4f, Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, DC_RR: %.4f' % (\
                                    epoch + 1, self.num_epochs, train_epoch_loss,\
                   acc, SE, SP, PC, F1, JS, DC, DC_RR))

                # Append train loss to train loss history (R&R)
                train_loss_history.append(train_epoch_loss)


                f = open(os.path.join(self.result_path, 'train_and_validation_history.csv'), 'a', \
                                     encoding = 'utf-8', newline= '')
                wr = csv.writer(f)
                wr.writerow(['Training', 'Epoch [%d/%d]' % (epoch + 1, self.num_epochs), \
                                         'Train loss: %.4f' % train_epoch_loss,\
                                        'Accuracy: %.4f' % acc, 'Sensitivity: %.4f' % SE, 'Specificity: %.4f' % SP, 'Precision: %.4f'% PC, \
                                        'F1 Score: %.4f' % F1, 'Jaccard Similarity: %.4f' % JS, 'Dice Coefficient: %.4f' % DC, 'RR_DC: %.4f' % DC_RR])
                f.close()

                #===================================== Validation ====================================#
                self.unet.train(False)
                self.unet.eval()

                acc = 0.  # Accuracy
                SE = 0.  # Sensitivity (Recall)
                SP = 0.  # Specificity
                PC = 0.  # Precision
                F1 = 0.  # F1 Score
                JS = 0.  # Jaccard Similarity
                DC = 0.  # Dice Coefficient
                DC_RR = 0
                length = 0

                for batch, (images, GT) in enumerate(self.valid_loader):

                    images = images.to(self.device)
                    GT = GT.to(self.device)
                    # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer.
                    images = images.reshape(self.batch_size, self.img_ch,
                                            np.shape(images)[1],
                                            np.shape(images)[2])
                    GT = GT.reshape(self.batch_size, self.img_ch,
                                    np.shape(GT)[1],
                                    np.shape(GT)[2])

                    #SR = F.sigmoid(self.unet(images))
                    SR = torch.sigmoid(self.unet(images))
                    acc += get_accuracy(SR, GT)
                    SE += get_sensitivity(SR, GT)
                    SP += get_specificity(SR, GT)
                    PC += get_precision(SR, GT)
                    F1 += get_F1(SR, GT)
                    JS += get_JS(SR, GT)
                    DC_RR += get_DC_RR(SR,
                                       GT,
                                       inverse_ratio=self.inverse_ratio)
                    DC += get_DC(SR, GT)

                    length += images.size(0)

                    # Compute the validation loss.
                    SR = self.unet(images)
                    SR_probs = torch.sigmoid(SR)
                    SR_flat = SR_probs.view(SR_probs.size(0), -1)
                    GT_flat = GT.view(GT.size(0), -1)
                    # use the dice coefficient loss instead of the BCE loss. (R&R)
                    validation_loss = self.dice_coeff_loss(SR_flat, GT_flat)

                    validation_epoch_loss += validation_loss.item()

                acc = acc / length
                SE = SE / length
                SP = SP / length
                PC = PC / length
                F1 = F1 / length
                JS = JS / length
                DC = DC / length
                DC_RR = DC_RR / length
                unet_score = DC_RR
                print('Current learning rate: {}'.format(self.current_lr))

                print(
                    'Epoch [%d/%d] \n[Validation] Validation Loss: %.4f, Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, DC: %.4f, DC_RR: %.4f'
                    % (epoch + 1, self.num_epochs, validation_epoch_loss, acc,
                       SE, SP, PC, F1, JS, DC, DC_RR))

                # Append validation loss to train loss history (R&R)
                validation_loss_history.append(validation_epoch_loss)
                '''
				torchvision.utils.save_image(images.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_image.png'%(self.model_type,epoch+1)))
				torchvision.utils.save_image(SR.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_SR.png'%(self.model_type,epoch+1)))
				torchvision.utils.save_image(GT.data.cpu(),
											os.path.join(self.result_path,
														'%s_valid_%d_GT.png'%(self.model_type,epoch+1)))
				'''

                f = open(os.path.join(self.result_path, 'train_and_validation_history.csv'), 'a', \
                                     encoding = 'utf-8', newline= '')
                wr = csv.writer(f)
                wr.writerow(['Validation', 'Epoch [%d/%d]' % (epoch + 1, self.num_epochs), \
                                         'Validation loss: %.4f' % validation_epoch_loss,\
                                        'Accuracy: %.4f' % acc, 'Sensitivity: %.4f' % SE, 'Specificity: %.4f' % SP, 'Precision: %.4f'% PC, \
                                        'F1 Score: %.4f' % F1, 'Jaccard Similarity: %.4f' % JS, 'Dice Coefficient: %.4f' % DC, 'RR_DC: %.4f' % DC_RR])
                f.close()

                # Save Best U-Net model
                if unet_score > best_unet_score:
                    best_unet_score = unet_score
                    best_epoch = epoch
                    best_unet = self.unet.state_dict()
                    print('Best %s model score : %.4f' %
                          (self.model_type, best_unet_score))
                    torch.save(best_unet, unet_path)

                # Early stop (R&R)
                #if (epoch > 8) and ((epoch - 4) % 5 == 0):
                #	if (np.median(validation_loss_history[-10:-5]) >= np.median(validation_loss_history[-5:])):
                #		print('Validation loss stop decreasing. Stop training.')
                #		break

                if (len(validation_loss_history) > 1):
                    if (validation_loss_history[-2] >=
                            validation_loss_history[-1]):
                        print(
                            'Validation loss stop decreasing. Stop training.')
                        break

        del self.unet
        try:
            del best_unet
        except:
            print(
                'Cannot delete the variable "best_unet": variable does not exist.'
            )

        return train_loss_history, validation_loss_history, val_batch_loss_history, train_batch_loss_history

    def test(self, result_csv_dir=None):
        """Evaluate a saved U-Net checkpoint on the test loader.

        Rebuilds the network, loads the checkpoint whose file name is derived
        from the current hyper-parameters, runs every batch of
        ``self.test_loader`` through it, saves each batch's sigmoid output to
        ``<self.result_img_path><batch index>.npy``, prints the averaged
        metrics and appends them as one row to ``result_compare.csv``.

        Args:
            result_csv_dir: Directory in which ``result_compare.csv`` is
                appended to. When ``None`` the original hard-coded results
                directory is used, so existing callers are unaffected.

        Raises:
            ZeroDivisionError: If the test loader yields no samples.
        """
        if result_csv_dir is None:
            # NOTE(review): machine-specific absolute path kept as the default
            # for backward compatibility; pass result_csv_dir to override.
            result_csv_dir = '/home/raphael/Projects/DL-Lung_Nodule_LUNA16/Solutions/RaphaelRosalie-solution/patch-based_U-net/results/'

        unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%.4f-%s-%s-%.4f.pkl' %
            (self.model_type, self.num_epochs, self.initial_lr,
             self.augmentation_prob, self.PPorLS, self.optimizer_choice,
             self.inverse_ratio))
        self.build_model()
        # map_location lets a checkpoint saved on a GPU be loaded on whatever
        # device this run is using (e.g. a CPU-only host).
        self.unet.load_state_dict(
            torch.load(unet_path, map_location=self.device))

        # eval() is the same as train(False); a single call is sufficient.
        self.unet.eval()

        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        F1 = 0.  # F1 Score
        JS = 0.  # Jaccard Similarity
        DC = 0.  # Dice Coefficient
        DC_RR = 0
        length = 0

        # Pure inference: disable autograd so no graph is kept per batch.
        with torch.no_grad():
            for i, (images, GT) in enumerate(self.test_loader):
                images = images.to(self.device)
                GT = GT.to(self.device)
                # Reshape the images and GT to 4-dimensional so that they can
                # get fed to the conv2d layer. The leading -1 lets the batch
                # dimension be inferred, so a smaller final batch no longer
                # fails the way a fixed self.batch_size did.
                images = images.reshape(-1, self.img_ch,
                                        images.shape[1], images.shape[2])
                GT = GT.reshape(-1, self.img_ch, GT.shape[1], GT.shape[2])

                SR = torch.sigmoid(self.unet(images))
                acc += get_accuracy(SR, GT)
                SE += get_sensitivity(SR, GT)
                SP += get_specificity(SR, GT)
                PC += get_precision(SR, GT)
                F1 += get_F1(SR, GT)
                JS += get_JS(SR, GT)
                DC_RR += get_DC_RR(SR, GT, inverse_ratio=self.inverse_ratio)
                DC += get_DC(SR, GT)
                length += images.size(0)

                # Persist the predicted probability map(s) of this batch.
                np_img = np.squeeze(SR.cpu().detach().numpy())
                np.save(self.result_img_path + str(i) + '.npy', np_img)

        # Per-batch metric values are summed above and normalised here by the
        # total number of samples seen.
        acc = acc / length
        SE = SE / length
        SP = SP / length
        PC = PC / length
        F1 = F1 / length
        JS = JS / length
        DC = DC / length
        DC_RR = DC_RR / length

        print('model type: ', self.model_type, 'accuracy: ', acc,
              'sensitivity: ', SE, 'specificity: ', SP, 'precision: ', PC,
              'F1 score: ', F1, 'Jaccard similarity: ', JS,
              'Dice Coefficient: ', DC, 'DC_RR: ', DC_RR)

        # The context manager closes the file even if writing the row raises.
        with open(os.path.join(result_csv_dir, 'result_compare.csv'),
                  'a',
                  encoding='utf-8',
                  newline='') as f:
            wr = csv.writer(f)
            wr.writerow([
                self.model_type, self.PPorLS,
                'Accuracy: %.4f' % acc,
                'Sensitivity: %.4f' % SE,
                'Specificity: %.4f' % SP,
                'Precision: %.4f' % PC,
                'F1 Score: %.4f' % F1,
                'Jaccard Similarity: %.4f' % JS,
                'Dice Coefficient: %.4f' % DC,
                'RR_DC: %.4f' % DC_RR,
                'inverse_ratio: %.3f' % self.inverse_ratio,
            ])

    def whole_slice_prediction(self):
        """Inference mode. Return whole slice prediction as a binary nodule mask."""
        unet_path = os.path.join(
            self.model_path, '%s-%d-%.4f-%.4f-%s-%s.pkl' %
            (self.model_type, self.num_epochs, self.initial_lr,
             self.augmentation_prob, self.PPorLS, self.optimizer_choice))
        self.build_model()
        self.unet.load_state_dict(torch.load(unet_path))

        self.unet.train(False)
        self.unet.eval()
        acc = 0.  # Accuracy
        SE = 0.  # Sensitivity (Recall)
        SP = 0.  # Specificity
        PC = 0.  # Precision
        F1 = 0.  # F1 Score
        JS = 0.  # Jaccard Similarity
        DC = 0.  # Dice Coefficient
        DC_RR = 0
        length = 0
        for batch, (images,
                    GT) in enumerate(self.whole_slice_prediction_loader):
            images = images.to(self.device)
            GT = GT.to(self.device)
            # Reshape the images and GT to 4-dimensional so that they can get fed to the conv2d layer.
            images = images.reshape(self.batch_size, self.img_ch,
                                    np.shape(images)[1],
                                    np.shape(images)[2])
            GT = GT.reshape(self.batch_size, self.img_ch,
                            np.shape(GT)[1],
                            np.shape(GT)[2])

            #SR = F.sigmoid(self.unet(images))
            SR = torch.sigmoid(self.unet(images))
            acc += get_accuracy(SR, GT)
            SE += get_sensitivity(SR, GT)
            SP += get_specificity(SR, GT)
            PC += get_precision(SR, GT)
            F1 += get_F1(SR, GT)
            JS += get_JS(SR, GT)
            DC_RR += get_DC_RR(SR, GT)
            DC += get_DC(SR, GT)
            length += images.size(0)
            np_img = np.squeeze(SR.cpu().detach().numpy())
            np.save(self.result_img_path + str(i) + '.npy', np_img)

        acc = acc / length
        SE = SE / length
        SP = SP / length
        PC = PC / length
        F1 = F1 / length
        JS = JS / length
        DC = DC / length
        DC_RR = DC_RR / length
        unet_score = DC