Example #1
def test(args):
    if not args.model:
        print('Need a pretrained model!')
        return

    if not args.color_labels:
        print('Need to specify color labels')
        return

    resize_img = args.image_width is not None and args.image_height is not None

    # check if output dir exists
    output_dir = args.output_dir if args.output_dir else 'test-{}'.format(
        utils.get_datetime_string())
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # load model
    model = networks.UNet(args.unet_layers, 3, len(args.color_labels))
    model.load_state_dict(torch.load(args.model,
                                     map_location='cpu' if args.cpu else None))
    model = model.eval()

    if not args.cpu:
        model.cuda()

    # iterate over all images one by one
    transform = torchvision.transforms.ToTensor()
    for filename in os.listdir(args.dataroot):
        filepath = os.sep.join([args.dataroot, filename])
        with open(filepath, 'rb') as f:  # images must be opened in binary mode
            img = Image.open(f)
            if resize_img:
                img = img.resize((args.image_width, args.image_height))
            img = transform(img)
            img = img.view(1, *img.shape)
            img = Variable(img)
        if not args.cpu:
            img = img.cuda()
        output = model(img)
        _, c, h, w = output.data.shape
        output_numpy = (output.data.numpy()[0] if args.cpu
                        else output.data.cpu().numpy()[0])
        output_argmax = numpy.argmax(output_numpy, axis=0)
        out_img = numpy.zeros((h, w, 3), dtype=numpy.uint8)
        for i, color in enumerate(args.color_labels):
            out_img[output_argmax == i] = numpy.array(color, dtype=numpy.uint8)
        out_img = Image.fromarray(out_img)
        seg_filepath = os.sep.join(
            [output_dir, filename[:filename.rfind('.')] + '.png'])
        out_img.save(seg_filepath)
        print('{} is exported!'.format(seg_filepath))
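The original command-line parser is not shown, so every flag below is inferred from the attributes that test() reads, and the --unet_layers default is a guess. A minimal sketch of wiring this entry point up with argparse:

# Hypothetical CLI wiring for test(); flag names are inferred, not the original parser.
import argparse
import json

parser = argparse.ArgumentParser()
parser.add_argument('--model', help='path to a trained UNet checkpoint (.pth)')
parser.add_argument('--dataroot', help='directory with images to segment')
parser.add_argument('--output_dir', default=None)
parser.add_argument('--image_width', type=int, default=None)
parser.add_argument('--image_height', type=int, default=None)
parser.add_argument('--unet_layers', type=int, default=5)  # default is an assumption
parser.add_argument('--color_labels', type=json.loads,
                    help='e.g. "[[0, 0, 0], [255, 255, 255]]"')
parser.add_argument('--cpu', action='store_true')

test(parser.parse_args())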
Example #2
def net_generator():
    gen_net = networks.UNet().to(device)
    optimizer = torch.optim.Adam(gen_net.parameters(), lr=0.001)
    
    # optimizer=torch.optim.SGD(direct_intrinsic_net.parameters(),lr=0.01,momentum=0.9)

    # --------------------------------------------------
    print_network(gen_net)

    # --------------------------------------------------
    # checkpoint = torch.load(
    #     "/mnt/1T-5e7/mycodehtml/prac_data_s/kaggle/tgs-salt-identification-challenge/train/checkpoint.pth.tar")

    # gen_net.load_state_dict(checkpoint['state_dict'])
    # optimizer.load_state_dict(checkpoint['optimizer'])
    
    return gen_net, optimizer
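The commented-out block above restores a checkpoint holding 'state_dict' and 'optimizer' keys; a matching save-side sketch (the file path is a placeholder) could be:

# Save-side counterpart to the commented-out restore above; the dict keys
# mirror that snippet, the file path is a placeholder.
torch.save({'state_dict': gen_net.state_dict(),
            'optimizer': optimizer.state_dict()},
           'checkpoint.pth.tar')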
Example #3
    def __init__(self, options):
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

        self.models = {}
        self.parameters_to_train = []

        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        # self.models["encoder"] = networks.ResnetEncoder(
        #     self.opt.num_layers, pretrained=False)
        # self.models["encoder"].to(self.device)
        # self.parameters_to_train += list(self.models["encoder"].parameters())
        #
        # self.models["decoder"] = networks.Decoder(
        #     self.models["encoder"].num_ch_enc)
        # self.models["decoder"].to(self.device)
        # self.parameters_to_train += list(self.models["decoder"].parameters())

        # Initialize the resnet50 and resnet101 model for this run
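        # NOTE: num_classes and feature_extract are not defined in this snippet;
        # they are assumed to come from the enclosing scope or options.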
        model_50, input_size = self.initialize_model("resnet50",
                                                     num_classes,
                                                     feature_extract,
                                                     use_pretrained=True)
        self.models["resnet50"] = model_50
        self.models["resnet50"].to(self.device)

        model_101, input_size = self.initialize_model("resnet101",
                                                      num_classes,
                                                      feature_extract,
                                                      use_pretrained=True)
        self.models["resnet101"] = model_101
        self.models["resnet101"].to(self.device)

        # self.models["RAN"] = DeepLab_ResNet101_MSC(n_classes=21)
        self.models["RAN"] = RAN(in_channels=2048, out_channels=128)
        self.models["RAN"].to(self.device)

        self.models["unet"] = networks.UNet(n_channels=1, n_classes=4)
        self.models["unet"].to(self.device)

        # self.models["unet"] = networks.UNet(n_channels=1, n_classes=4)
        # self.models["unet"].to(self.device)

        # self.parameters_to_train += list(self.models["unet"].parameters())
        # self.parameters_to_train += list(self.models["resnet50"].parameters())
        # self.parameters_to_train += list(self.models["resnet101"].parameters())

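        # NOTE: rescale_transform is not defined in this snippet; presumably it
        # maps the random tensor into a valid image range (an assumption).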
        self.parameters_to_train = nn.Parameter(
            rescale_transform(torch.normal(mean=0.5, std=1,
                                           size=(1, 3, 512, 512),
                                           device=self.device)),
            requires_grad=True)
        '''
        w = Variable(torch.randn(3, 5), requires_grad=True)
        b = Variable(torch.randn(3, 5), requires_grad=True)
        self.parameters_to_train += w
        self.parameters_to_train += b
        '''

        #self.model_optimizer = optim.SGD(self.parameters_to_train,self.opt.learning_rate,momentum=0.9,weight_decay=0.0005)
        self.model_optimizer = optim.Adam([self.parameters_to_train],
                                          self.opt.learning_rate)
        '''
        self.model_optimizer = optim.Adam(self.parameters_to_train,
                                          self.opt.learning_rate)
        '''

        self.dataset = datasets.Retouch_dataset

        if self.opt.use_augmentation:
            self.transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomVerticalFlip(p=0.5),
                #transforms.RandomRotation(degrees=(-20, 20)),
            ])
        else:
            self.transform = None

        # self.criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(self.opt.ce_weighting).to(self.device),
        #                                      ignore_index=self.opt.ignore_idx)
        self.criterion = nn.CrossEntropyLoss(reduction='none')

        train_dataset = self.dataset(base_dir=self.opt.base_dir,
                                     list_dir=self.opt.list_dir,
                                     split='train',
                                     is_train=True,
                                     transform=self.transform)

        self.train_loader = DataLoader(train_dataset,
                                       self.opt.batch_size,
                                       shuffle=True,
                                       num_workers=self.opt.num_workers,
                                       pin_memory=True,
                                       drop_last=True)

        val_dataset = self.dataset(base_dir=self.opt.base_dir,
                                   list_dir=self.opt.list_dir,
                                   split='val',
                                   is_train=False,
                                   transform=self.transform)

        self.val_loader = DataLoader(val_dataset,
                                     self.opt.batch_size,
                                     shuffle=True,
                                     num_workers=self.opt.num_workers,
                                     pin_memory=True,
                                     drop_last=True)

        self.val_iter = iter(self.val_loader)

        num_train_samples = len(train_dataset)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        self.writers = {}
        for mode in ["train", "val"]:
            self.writers[mode] = SummaryWriter(
                os.path.join(self.log_path, mode))
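Unusually, self.parameters_to_train here is a learnable input image rather than network weights, so the Adam optimizer updates pixels instead of parameters. A hypothetical single step under that reading; the objective is a placeholder, not the repository's loss:

# Hypothetical step: gradients flow into the image tensor itself.
img = self.parameters_to_train               # (1, 3, 512, 512), requires_grad=True
feats = self.models["resnet50"](img)         # any differentiable network works here
loss = feats.norm()                          # placeholder objective, an assumption
self.model_optimizer.zero_grad()
loss.backward()
self.model_optimizer.step()                  # updates the image, not the networks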
Example #4
    def __init__(self, options):
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

        self.models = {}
        self.parameters_to_train = []
        self.parameters_to_train_F = []
        self.parameters_to_train_D = []

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        # self.models["encoder"] = networks.ResnetEncoder(
        #     self.opt.num_layers, pretrained=False)
        # self.models["encoder"].to(self.device)
        # self.parameters_to_train += list(self.models["encoder"].parameters())
        #
        # self.models["decoder"] = networks.Decoder(
        #     self.models["encoder"].num_ch_enc)
        # self.models["decoder"].to(self.device)
        # self.parameters_to_train += list(self.models["decoder"].parameters())

        # Initialize the resnet50 and resnet101 model for this run
        model_50 = self.initialize_model("resnet50", requires_grad=False)
        self.models["resnet50"] = model_50
        self.models["resnet50"].to(self.device)

        model_101 = self.initialize_model("resnet101", requires_grad=False)
        self.models["resnet101"] = model_101
        self.models["resnet101"].to(self.device)

        # self.models["RAN"] = DeepLab_ResNet101_MSC(n_classes=21)
        self.models["RAN"] = RAN(in_channels=512, out_channels=21)
        self.models["RAN"].to(self.device)

        self.models["unet"] = networks.UNet(n_channels=1, n_classes=4)
        self.models["unet"].to(self.device)
        self.parameters_to_train += list(self.models["unet"].parameters())

        # Optimizers
        self.model_optimizer = optim.Adam(self.parameters_to_train, self.opt.learning_rate)
        self.load_model()

        model_unet_encoder = self.initialize_model("unet_encoder", requires_grad=False)
        self.models["unet_encoder"] = model_unet_encoder
        self.models["unet_encoder"].to(self.device)

        self.parameters_to_train_F += list(self.models["unet_encoder"].parameters())
        self.parameters_to_train_D += list(self.models["RAN"].parameters())

        self.optimizer_F = optim.Adam(self.parameters_to_train_F, self.opt.learning_rate)
        self.optimizer_D = optim.Adam(self.parameters_to_train_D, self.opt.learning_rate)

        # self.models["unet_down4"] = UNet_Layer(output_layer='down4')
        # self.models["unet_down4"].to(self.device)
        # self.parameters_to_train += list(self.models["unet"].parameters())
        # self.parameters_to_train += list(self.models["resnet50"].parameters())
        # self.parameters_to_train += list(self.models["resnet101"].parameters())


        '''
        w = Variable(torch.randn(3, 5), requires_grad=True)
        b = Variable(torch.randn(3, 5), requires_grad=True)
        self.parameters_to_train += w
        self.parameters_to_train += b
        '''



        '''
        self.model_optimizer = optim.Adam(self.parameters_to_train,
                                          self.opt.learning_rate)
        '''

        self.dataset = datasets.Retouch_dataset
        self.coco_dataset = datasets.Coco_dataset

        self.transform = None
        '''
        self.transform = transforms.Compose([
            transforms.Normalize(mean=[0.1422], std=[0.0885])
        ])
        '''

        '''
        if self.opt.use_augmentation:
            self.transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),
                                                 transforms.RandomVerticalFlip(p=0.5),
                                                 #transforms.RandomRotation(degrees=(-20, 20)),
                                                 ])
        else:
            self.transform = None
        '''

        # self.criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(self.opt.ce_weighting).to(self.device),
        #                                      ignore_index=self.opt.ignore_idx)
        self.criterion = nn.CrossEntropyLoss(reduction='none')

        self.source_dataset_AAN, self.source_dir_AAN = self.initialize_dataset_AAN("cirrus_val")
        self.target_dataset_AAN, self.target_dir_AAN = self.initialize_dataset_AAN("spectralis")

        self.source_dataloader_AAN = DataLoader(
            self.source_dataset_AAN,
            1,
            shuffle=True,
            num_workers=self.opt.num_workers,
            pin_memory=True,
            drop_last=True)

        self.target_dataloader_AAN = DataLoader(
            self.target_dataset_AAN,
            1,
            shuffle=True,
            num_workers=self.opt.num_workers,
            pin_memory=True,
            drop_last=True)

        train_dataset = self.dataset(
            base_dir=self.opt.base_dir,
            list_dir=self.opt.list_dir,
            split='train',
            is_train=True,
            transform=self.transform)

        self.train_loader = DataLoader(
            train_dataset,
            self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.num_workers,
            pin_memory=True,
            drop_last=True)

        val_dataset = self.dataset(
            base_dir=self.opt.base_dir,
            list_dir=self.opt.list_dir,
            split='val',
            is_train=False,
            transform=self.transform)

        self.val_loader = DataLoader(
            val_dataset,
            self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.num_workers,
            pin_memory=True,
            drop_last=True)

        self.val_iter = iter(self.val_loader)

        num_train_samples = len(train_dataset)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        self.writers = {}
        for mode in ["train", "val", "AAN", "RAN"]:
            self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))
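optimizer_F (the unet encoder) and optimizer_D (the RAN discriminator) suggest a GAN-style domain-adaptation loop. A sketch of one alternating update, assuming src_img and tgt_img batches come from the two AAN loaders; the BCE losses and the 1/0 domain labels are assumptions, not the repository's code:

# Alternating adversarial update (a sketch, written as if inside the same class).
bce = nn.BCEWithLogitsLoss()
src_feat = self.models["unet_encoder"](src_img)   # source-domain features
tgt_feat = self.models["unet_encoder"](tgt_img)   # target-domain features

# discriminator step: tell source (1) from target (0); features are detached
d_src = self.models["RAN"](src_feat.detach())
d_tgt = self.models["RAN"](tgt_feat.detach())
loss_D = bce(d_src, torch.ones_like(d_src)) + bce(d_tgt, torch.zeros_like(d_tgt))
self.optimizer_D.zero_grad()
loss_D.backward()
self.optimizer_D.step()

# encoder step: make target features indistinguishable from source ones
d_tgt = self.models["RAN"](self.models["unet_encoder"](tgt_img))
loss_F = bce(d_tgt, torch.ones_like(d_tgt))
self.optimizer_F.zero_grad()
loss_F.backward()
self.optimizer_F.step()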
Example #5
	def build_model_graph(self):
		print("{}: Start to build model graph...".format(datetime.datetime.now()))

		self.global_step_op = tf.train.get_or_create_global_step()		

		if self.dimension == 2:
			input_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], self.input_channel_num) 
			output_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], 1) 
		elif self.dimension == 3:
			input_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], self.patch_shape[2], self.input_channel_num) 
			output_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], self.patch_shape[2], 1) 
		else:
			sys.exit('Invalid Patch Shape (length should be 2 or 3)')

		self.images_placeholder, self.labels_placeholder = self.placeholder_inputs(input_batch_shape,output_batch_shape)

		# plot input and output images to tensorboard
		if self.image_log:
			if self.dimension == 2:
				for image_channel in range(self.input_channel_num):
					image_log = tf.cast(self.images_placeholder[:,:,:,image_channel:image_channel+1], dtype=tf.uint8)
					tf.summary.image(self.image_filenames[image_channel], image_log, max_outputs=self.batch_size)
				if 0 in self.label_classes:
					labels_log = tf.cast(self.labels_placeholder*math.floor(255/(self.output_channel_num-1)), dtype=tf.uint8)
				else:
					labels_log = tf.cast(self.labels_placeholder*math.floor(255/self.output_channel_num), dtype=tf.uint8)
				tf.summary.image("label",labels_log, max_outputs=self.batch_size)
			else:
				for batch in range(self.batch_size):
					for image_channel in range(self.input_channel_num):
						image_log = tf.cast(self.images_placeholder[batch:batch+1,:,:,:,image_channel], dtype=tf.uint8)
						tf.summary.image(self.image_filenames[image_channel], tf.transpose(image_log,[3,1,2,0]),max_outputs=self.patch_shape[-1])
					if 0 in self.label_classes:
						labels_log = tf.cast(self.labels_placeholder[batch:batch+1,:,:,:,0]*math.floor(255/(self.output_channel_num-1)),dtype=tf.uint8)
					else:
						labels_log = tf.cast(self.labels_placeholder[batch:batch+1,:,:,:,0]*math.floor(255/self.output_channel_num), dtype=tf.uint8)
					tf.summary.image("label", tf.transpose(labels_log,[3,1,2,0]),max_outputs=self.patch_shape[-1])

		# Get images and labels
		# create transformations to image and labels
		# Force input pipepline to CPU:0 to avoid operations sometimes ended up at GPU and resulting a slow down
		with tf.device('/cpu:0'):
			if self.dimension == 2:
				train_transforms_3d = []
				
				train_transforms_2d = [
					NiftiDataset2D.ManualNormalization(0,300),
					NiftiDataset2D.Resample(self.spacing),
					NiftiDataset2D.Padding(self.patch_shape),
					NiftiDataset2D.RandomCrop(self.patch_shape)
				]

				test_transforms_3d = []

				test_transforms_2d = [
					NiftiDataset2D.ManualNormalization(0,300),
					NiftiDataset2D.Resample(self.spacing),
					NiftiDataset2D.Padding(self.patch_shape),
					NiftiDataset2D.RandomCrop(self.patch_shape)
				]

				trainTransforms = {"3D": train_transforms_3d, "2D": train_transforms_2d}
				testTransforms = {"3D": test_transforms_3d, "2D": test_transforms_2d}
			else:
				trainTransforms = [
					# NiftiDataset.Normalization(),
					# NiftiDataset3D.ExtremumNormalization(0.1),
					# NiftiDataset3D.ManualNormalization(0,300),
					NiftiDataset3D.StatisticalNormalization(2.5),
					NiftiDataset3D.Resample((self.spacing[0],self.spacing[1],self.spacing[2])),
					NiftiDataset3D.Padding((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2])),
					NiftiDataset3D.RandomCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),self.drop_ratio, self.min_pixel),
					# NiftiDataset.ConfidenceCrop((FLAGS.patch_size*3, FLAGS.patch_size*3, FLAGS.patch_layer*3),(0.0001,0.0001,0.0001)),
					# NiftiDataset.BSplineDeformation(randomness=2),
					# NiftiDataset.ConfidenceCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),(0.5,0.5,0.5)),
					# NiftiDataset3D.ConfidenceCrop2((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),rand_range=32,probability=0.8),
					# NiftiDataset3D.RandomFlip([True, False, False]),
					NiftiDataset3D.RandomNoise()
					]

				# use random crop for testing
				testTransforms = [
					# NiftiDataset.Normalization(),
					# NiftiDataset3D.ExtremumNormalization(0.1),
					# NiftiDataset3D.ManualNormalization(0,300),
					NiftiDataset3D.StatisticalNormalization(2.5),
					NiftiDataset3D.Resample((self.spacing[0],self.spacing[1],self.spacing[2])),
					NiftiDataset3D.Padding((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2])),
					NiftiDataset3D.RandomCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),self.drop_ratio, self.min_pixel)
					# NiftiDataset.ConfidenceCrop((FLAGS.patch_size*2, FLAGS.patch_size*2, FLAGS.patch_layer*2),(0.0001,0.0001,0.0001)),
					# NiftiDataset.BSplineDeformation(),
					# NiftiDataset.ConfidenceCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),(0.75,0.75,0.75)),
					# NiftiDataset.ConfidenceCrop2((FLAGS.patch_size, FLAGS.patch_size, FLAGS.patch_layer),rand_range=32,probability=0.8),
					# NiftiDataset.RandomFlip([True, False, False]),
					]

			# get input and output datasets
			self.train_iterator = self.dataset_iterator(self.train_data_dir, trainTransforms)
			self.next_element_train = self.train_iterator.get_next()

			if self.testing:
				self.test_iterator = self.dataset_iterator(self.test_data_dir, testTransforms)
				self.next_element_test = self.test_iterator.get_next()

		print("{}: Dataset pipeline complete".format(datetime.datetime.now()))

		# network models:
		if self.network_name == "FCN":
			sys.exit("Network to be developed")
		elif self.network_name == "UNet":
			self.network = networks.UNet(
				num_output_channels=self.output_channel_num,
				dropout_rate=0.01,
				num_channels=4,
				num_levels=4,
				num_convolutions=2,
				bottom_convolutions=2,
				is_training=True,
				activation_fn="relu"
				)
		elif self.network_name =="VNet":
			self.network = networks.VNet(
				num_classes=self.output_channel_num,
				dropout_rate=self.dropout_rate,
				num_channels=16,
				num_levels=4,
				num_convolutions=(1, 2, 3, 3),
				bottom_convolutions=3,
				is_training = True,
				activation_fn="prelu"
				)
		else:
			sys.exit("Invalid Network")

		print("{}: Core network complete".format(datetime.datetime.now()))

		self.logits = self.network.GetNetwork(self.images_placeholder)

		# softmax op
		self.softmax_op = tf.nn.softmax(self.logits,name="softmax")

		if self.image_log:
			if self.dimension == 2:
				for output_channel in range(self.output_channel_num):
					# softmax_log = grayscale_to_rainbow(self.softmax_op[:,:,:,output_channel:output_channel+1])
					softmax_log = self.softmax_op[:,:,:,output_channel:output_channel+1]
					softmax_log = tf.cast(softmax_log*255, dtype = tf.uint8)
					tf.summary.image("softmax_" + str(self.label_classes[output_channel]),softmax_log,max_outputs=self.batch_size)
			else:
				for batch in range(self.batch_size):
					for output_channel in range(self.output_channel_num):
						softmax_log = grayscale_to_rainbow(tf.transpose(self.softmax_op[batch:batch+1,:,:,:,output_channel],[3,1,2,0]))
						softmax_log = tf.cast(softmax_log*255,dtype=tf.uint8)
						tf.summary.image("softmax_" + str(self.label_classes[output_channel]),softmax_log,max_outputs=self.patch_shape[-1])

		print("{}: Output layers complete".format(datetime.datetime.now()))

		# loss function
		with tf.name_scope("loss"):
			# """
			# 	Tricks for faster converge: Here we provide two calculation methods, first one will ignore  to classical dice formula
			# 	method 1: exclude the 0-th label in dice calculation. to use this method properly, you must set 0 as the first value in SegmentationClasses in config.json
			# 	method 2: dice will be average on all classes
			# """
			if self.dimension == 2:
				labels = tf.one_hot(self.labels_placeholder[:,:,:,0], depth=self.output_channel_num)
			else:
				labels = tf.one_hot(self.labels_placeholder[:,:,:,:,0], depth=self.output_channel_num)

			# if 0 in self.label_classes:
			# 	################### method 1 ###################
			# 	if self.dimension ==2:
			# 		labels = labels[:,:,:,1:]
			# 		softmax = self.softmax_op[:,:,:,1:]
			# 	else:
			# 		labels = labels[:,:,:,:,1:]
			# 		softmax = self.softmax_op[:,:,:,:,1:]
			# else:
			# 	################### method 2 ###################
			# 	labels = labels
			# 	softmax = self.softmax_op

			# method 2: keep all classes (see the commented-out method 1 above)
			softmax = self.softmax_op

			if (self.loss_name == "sorensen"):
				if self.dimension == 2:
					sorensen = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='sorensen',axis=(1,2))
				else:
					sorensen = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='sorensen')
				self.loss_op = 1. - sorensen
			elif (self.loss_name == "weighted_sorensen"):
				if self.dimension == 2:
					sorensen = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='sorensen', axis=(1,2), weighted=True)
				else:
					sorensen = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='sorensen', weighted=True)
				self.loss_op = 1. - sorensen
			elif (self.loss_name == "jaccard"):
				if self.dimension == 2:
					jaccard = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='jaccard',axis=(1,2))
				else:
					jaccard = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='jaccard')
				self.loss_op = 1. - jaccard
			elif (self.loss_name == "weightd_jaccard"):
				if self.dimension == 2:
					jaccard = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='jaccard',axis=(1,2), weighted=True)
				else:
					jaccard = dice_coe(softmax,tf.cast(labels,dtype=tf.float32), loss_type='jaccard', weighted=True)
				self.loss_op = 1. - jaccard
			else:
				sys.exit("Invalid loss function")

		tf.summary.scalar('loss', self.loss_op)

		print("{}: Loss function complete".format(datetime.datetime.now()))

		# argmax op
		with tf.name_scope("predicted_label"):
			self.pred_op = tf.argmax(self.logits, axis=-1 , name="prediction")

		if self.image_log:
			if self.dimension == 2:
				if 0 in self.label_classes:
					pred_log = tf.cast(self.pred_op*math.floor(255/(self.output_channel_num-1)),dtype=tf.uint8)
				else:
					pred_log = tf.cast(self.pred_op*math.floor(255/self.output_channel_num),dtype=tf.uint8)
				pred_log = tf.expand_dims(pred_log,axis=-1)
				tf.summary.image("pred", pred_log, max_outputs=self.batch_size)
			else:
				for batch in range(self.batch_size):
					if 0 in self.label_classes:
						pred_log = tf.cast(self.pred_op[batch:batch+1,:,:,:]*math.floor(255/(self.output_channel_num-1)), dtype=tf.uint8)
					else:
						pred_log = tf.cast(self.pred_op[batch:batch+1,:,:,:]*math.floor(255/(self.output_channel_num)), dtype=tf.uint8)
					
					tf.summary.image("pred", tf.transpose(pred_log,[3,1,2,0]),max_outputs=self.patch_shape[-1])

		# accuracy of the model
		with tf.name_scope("metrics"):
			correct_pred = tf.equal(tf.expand_dims(self.pred_op,-1), tf.cast(self.labels_placeholder,dtype=tf.int64))
			accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

			tf.summary.scalar('accuracy', accuracy)

			# confusion matrix
			if self.dimension == 2:
				label_one_hot = tf.one_hot(self.labels_placeholder[:,:,:,0], depth=self.output_channel_num)
				pred_one_hot = tf.one_hot(self.pred_op, depth=self.output_channel_num)
			else:
				label_one_hot = tf.one_hot(self.labels_placeholder[:,:,:,:,0],depth=self.output_channel_num)
				pred_one_hot = tf.one_hot(self.pred_op[:,:,:,:], depth=self.output_channel_num)

			for i in range(1, self.output_channel_num):  # class 0 (background) is skipped
				if self.dimension == 2:
					tp, tp_op = tf.metrics.true_positives(label_one_hot[:,:,:,i], pred_one_hot[:,:,:,i], name="true_positives_"+str(self.label_classes[i]))
					tn, tn_op = tf.metrics.true_negatives(label_one_hot[:,:,:,i], pred_one_hot[:,:,:,i], name="true_negatives_"+str(self.label_classes[i]))
					fp, fp_op = tf.metrics.false_positives(label_one_hot[:,:,:,i], pred_one_hot[:,:,:,i], name="false_positives_"+str(self.label_classes[i]))
					fn, fn_op = tf.metrics.false_negatives(label_one_hot[:,:,:,i], pred_one_hot[:,:,:,i], name="false_negatives_"+str(self.label_classes[i]))
				else:
					tp, tp_op = tf.metrics.true_positives(label_one_hot[:,:,:,:,i], pred_one_hot[:,:,:,:,i], name="true_positives_"+str(self.label_classes[i]))
					tn, tn_op = tf.metrics.true_negatives(label_one_hot[:,:,:,:,i], pred_one_hot[:,:,:,:,i], name="true_negatives_"+str(self.label_classes[i]))
					fp, fp_op = tf.metrics.false_positives(label_one_hot[:,:,:,:,i], pred_one_hot[:,:,:,:,i], name="false_positives_"+str(self.label_classes[i]))
					fn, fn_op = tf.metrics.false_negatives(label_one_hot[:,:,:,:,i], pred_one_hot[:,:,:,:,i], name="false_negatives_"+str(self.label_classes[i]))
				sensitivity_op = tf.divide(tf.cast(tp_op,tf.float32),tf.cast(tf.add(tp_op,fn_op),tf.float32))
				specificity_op = tf.divide(tf.cast(tn_op,tf.float32),tf.cast(tf.add(tn_op,fp_op),tf.float32))
				dice_op = 2.*tp_op/(2.*tp_op+fp_op+fn_op)

				tf.summary.scalar('sensitivity_'+str(self.label_classes[i]), sensitivity_op)
				tf.summary.scalar('specificity_'+str(self.label_classes[i]), specificity_op)
				tf.summary.scalar('dice_'+str(self.label_classes[i]), dice_op)

		print("{}: Metrics complete".format(datetime.datetime.now()))

		print("{}: Build graph complete".format(datetime.datetime.now()))
Example #6
    def __init__(self, options):
        self.opt = options
        self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

        self.models = {}
        self.parameters_to_train = []

        self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

        # self.models["encoder"] = networks.ResnetEncoder(
        #     self.opt.num_layers, pretrained=False)
        # self.models["encoder"].to(self.device)
        # self.parameters_to_train += list(self.models["encoder"].parameters())
        #
        # self.models["decoder"] = networks.Decoder(
        #     self.models["encoder"].num_ch_enc)
        # self.models["decoder"].to(self.device)
        # self.parameters_to_train += list(self.models["decoder"].parameters())

        self.models["unet"] = networks.UNet(n_channels=1, n_classes=4)

        self.models["unet"].to(self.device)
        self.parameters_to_train += list(self.models["unet"].parameters())

        self.model_optimizer = optim.Adam(self.parameters_to_train,
                                          self.opt.learning_rate)

        self.dataset = datasets.Retouch_dataset

        if self.opt.use_augmentation:
            self.transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),
                                                 transforms.RandomVerticalFlip(p=0.5),
                                                 #transforms.RandomRotation(degrees=(-20, 20)),
                                                 ])
        else:
            self.transform = None
        # self.criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(self.opt.ce_weighting).to(self.device),
        #                                      ignore_index=self.opt.ignore_idx)
        self.criterion = nn.CrossEntropyLoss(reduction='none')

        train_dataset = self.dataset(
            base_dir=self.opt.base_dir,
            list_dir=self.opt.list_dir,
            split='train',
            is_train=True,
            transform=self.transform)

        self.train_loader = DataLoader(
            train_dataset,
            self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.num_workers,
            pin_memory=True,
            drop_last=True)

        val_dataset = self.dataset(
            base_dir=self.opt.base_dir,
            list_dir=self.opt.list_dir,
            split='val',
            is_train=False,
            transform=self.transform)

        self.val_loader = DataLoader(
            val_dataset,
            self.opt.batch_size,
            shuffle=True,
            num_workers=self.opt.num_workers,
            pin_memory=True,
            drop_last=True)

        self.val_iter = iter(self.val_loader)

        num_train_samples = len(train_dataset)
        self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

        self.writers = {}
        for mode in ["train", "val"]:
            self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))
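A sketch of one training step consistent with this __init__; the (image, label) batch format is an assumption. Because the criterion uses reduction='none', it returns a per-pixel loss map that must be reduced manually:

# One hypothetical training step; batch format is an assumption.
for img, seg in self.train_loader:
    img = img.to(self.device)
    seg = seg.long().to(self.device)             # class indices, (N, H, W)
    logits = self.models["unet"](img)            # (N, 4, H, W)
    loss = self.criterion(logits, seg).mean()    # reduce the per-pixel loss map
    self.model_optimizer.zero_grad()
    loss.backward()
    self.model_optimizer.step()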
Example #7
# GPU enabled
cuda = torch.cuda.is_available()

# cross-entropy loss: weighting of negative vs positive pixels and NLL loss layer
loss_weight = torch.FloatTensor([0.01, 0.99])
if cuda:
    # Obtaining log-probabilities from a network is easily achieved by adding a LogSoftmax
    # layer as the last layer of your network. You may use CrossEntropyLoss instead, if you
    # prefer not to add an extra layer.
    loss_weight = loss_weight.cuda()

criterion = nn.NLLLoss(weight=loss_weight)

# network and optimizer
# net = networks.VNet_Xtra(dice=dice, dropout=dropout, context=context)
net = networks.UNet()
if cuda:
    net = torch.nn.DataParallel(net,
                                device_ids=list(
                                    range(torch.cuda.device_count()))).cuda()

optimizer = optim.Adam(net.parameters(), lr=lr)

# train data loader
# train = LiverDataSet(directory=train_folder, augment=augment, context=context)
train = LiverDataSet(directory=train_folder, augment=augment)
train_sampler = torch.utils.data.sampler.WeightedRandomSampler(
    weights=train.getWeights(), num_samples=num_samples)
# train_data = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, sampler=train_sampler,
#                                          num_workers=2)
train_data = torch.utils.data.DataLoader(train,
                                         batch_size=batch_size,
                                         sampler=train_sampler,
                                         num_workers=2)  # shuffle must stay off when a sampler is used
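Since the criterion is NLLLoss, the network must output log-probabilities, as the comment above notes. A minimal training iteration under that assumption (the batch format is inferred, not shown in the snippet):

# NLLLoss consumes log-probabilities, so apply log_softmax to the raw output.
import torch.nn.functional as F

for img, seg in train_data:
    if cuda:
        img, seg = img.cuda(), seg.cuda()
    log_probs = F.log_softmax(net(img), dim=1)  # (N, 2, H, W) log-probabilities
    loss = criterion(log_probs, seg)            # seg: (N, H, W) class indices
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()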
Example #8
def train(args):
    # set logger
    logging_dir = args.output_dir if args.output_dir else 'train-{}'.format(
        utils.get_datetime_string())
    os.mkdir(logging_dir)
    logging.basicConfig(level=logging.INFO,
                        filename='{}/log.txt'.format(logging_dir),
                        format='%(asctime)s %(message)s',
                        filemode='w')

    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

    logging.info('=========== Task {} started! ==========='.format(logging_dir))
    for arg in vars(args):
        logging.info('{}: {}'.format(arg, getattr(args, arg)))
    logging.info('========================================')

    # initialize loader
    train_set = utils.SegmentationImageFolder(
        os.sep.join([args.dataroot, 'train']),
        image_folder=args.img_dir,
        segmentation_folder=args.seg_dir,
        labels=args.color_labels,
        image_size=(args.image_width, args.image_height),
        random_horizontal_flip=args.random_horizontal_flip,
        random_rotation=args.random_rotation,
        random_crop=args.random_crop,
        random_square_crop=args.random_square_crop,
        label_regr=args.regression)
    val_set = utils.SegmentationImageFolder(
        os.sep.join([args.dataroot, 'val']),
        image_folder=args.img_dir,
        segmentation_folder=args.seg_dir,
        labels=args.color_labels,
        image_size=(args.image_width, args.image_height),
        random_square_crop=args.random_square_crop,
        label_regr=args.regression)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.val_batch_size)

    # initialize model, input channels need to be calculated by hand
    n_classes = len(args.color_labels)
    if args.regression:
        model = networks.UNet(args.unet_layers, 3, 1, use_bn=args.batch_norm)
    else:
        model = networks.UNet(args.unet_layers,
                              3,
                              n_classes,
                              use_bn=args.batch_norm)
    if not args.cpu:
        model.cuda()

    criterion = nn.MSELoss() if args.regression else utils.CrossEntropyLoss2D()

    # train
    iterations = 0
    for epoch in range(args.epochs):
        model.train()
        # update lr according to lr policy
        if epoch in args.lr_policy:
            lr = args.lr_policy[epoch]
            optimizer = utils.get_optimizer(args.optimizer,
                                            model.parameters(),
                                            lr=lr,
                                            momentum=args.momentum,
                                            nesterov=args.nesterov)
            if epoch > 0:
                logging.info(
                    '| Learning Rate | Epoch: {: >3d} | Change learning rate to {}'
                    .format(epoch + 1, lr))
            else:
                logging.info(
                    '| Learning Rate | Initial learning rate: {}'.format(lr))

        # iterate all samples
        losses = utils.AverageMeter()
        for i_batch, (img, seg) in enumerate(train_loader):

            img = Variable(img)
            seg = Variable(seg)

            if not args.cpu:
                img = img.cuda()
                seg = seg.cuda()

            # compute output
            output = model(img)
            loss = criterion(output, seg)
            losses.update(loss.item())  # .item() extracts the scalar loss value

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # logging training curve
            if iterations % args.print_interval == 0:
                logging.info('| Iterations: {: >6d} '
                             '| Epoch: {: >3d}/{: >3d} '
                             '| Batch: {: >4d}/{: >4d} '
                             '| Training loss: {:.6f}'.format(
                                 iterations, epoch + 1, args.epochs, i_batch,
                                 len(train_loader) - 1, losses.avg))
                losses = utils.AverageMeter()

            # validation on all val samples
            if iterations % args.validation_interval == 0:
                model.eval()
                val_losses = utils.AverageMeter()
                gt_pixel_count = [0] * n_classes
                pred_pixel_count = [0] * n_classes
                intersection_pixel_count = [0] * n_classes
                union_pixel_count = [0] * n_classes

                for img, seg in val_loader:

                    img = Variable(img)
                    seg = Variable(seg)

                    if not args.cpu:
                        img = img.cuda()
                        seg = seg.cuda()

                    # compute output
                    output = model(img)
                    loss = criterion(output, seg)
                    val_losses.update(
                        loss.item(),
                        float(img.size(0)) / float(args.batch_size))
                    output_numpy = (output.data.numpy() if args.cpu
                                    else output.data.cpu().numpy())
                    pred_labels = numpy.argmax(output_numpy, axis=1)
                    gt_labels = (seg.data.numpy() if args.cpu
                                 else seg.data.cpu().numpy())

                    pred_labels = pred_labels.flatten()
                    gt_labels = gt_labels.flatten()

                    for i in range(n_classes):
                        pred_pixel_count[i] += (pred_labels == i).sum()
                        gt_pixel_count[i] += (gt_labels == i).sum()
                        gt_dumb = numpy.full(gt_labels.shape, -1,
                                             dtype=numpy.int64)
                        pred_dumb = numpy.full(pred_labels.shape, -2,
                                               dtype=numpy.int64)
                        gt_dumb[gt_labels == i] = 0
                        pred_dumb[pred_labels == i] = 0
                        intersection_pixel_count[i] += (
                            gt_dumb == pred_dumb).sum()
                        pred_dumb[gt_labels == i] = 0
                        union_pixel_count[i] += (pred_dumb == 0).sum()

                # calculate mPA & mIOU
                mPA = 0
                mIOU = 0
                for i in range(n_classes):
                    mPA += float(intersection_pixel_count[i]) / float(
                        gt_pixel_count[i])
                    mIOU += float(intersection_pixel_count[i]) / float(
                        union_pixel_count[i])
                mPA /= float(n_classes)
                mIOU /= float(n_classes)

                logging.info('| Iterations: {: >6d} '
                             '| Epoch: {: >3d}/{: >3d} '
                             '| Average mPA: {:.4f} '
                             '| Average mIOU: {:.4f} '
                             '| Validation loss: {:.6f} '.format(
                                 iterations, epoch + 1, args.epochs, mPA, mIOU,
                                 val_losses.avg))

                model.train()

            if iterations % args.checkpoint_interval == 0 and iterations > 0:
                model_weights_path = '{}/iterations-{:0>6d}-epoch-{:0>3d}.pth'.format(
                    logging_dir, iterations, epoch + 1)
                torch.save(model.state_dict(), model_weights_path)
                logging.info(
                    '| Checkpoint | {} is saved!'.format(model_weights_path))

            iterations += 1
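utils.CrossEntropyLoss2D is not shown in this snippet; a common minimal implementation (an assumption, not the repository's code) applies log_softmax over the channel dimension followed by a per-pixel NLL loss:

# Minimal per-pixel cross entropy for (N, C, H, W) logits vs (N, H, W) labels;
# a sketch of what utils.CrossEntropyLoss2D might look like, not the original.
import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLoss2D(nn.Module):
    def forward(self, logits, target):
        return F.nll_loss(F.log_softmax(logits, dim=1), target)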