def test(args):
    """Run a trained UNet over every image in ``args.dataroot`` and export a
    per-pixel color segmentation PNG for each into the output directory.

    Args:
        args: parsed CLI namespace; requires ``model`` (state-dict path) and
            ``color_labels`` (list of RGB triples, one per class); optional
            ``image_width``/``image_height`` resize, ``output_dir``, ``cpu``.
    """
    if not args.model:
        print('Need a pretrained model!')
        return
    if not args.color_labels:
        print('Need to specify color labels')
        return
    # Fix: the original computed this flag but then resized unconditionally,
    # crashing with a TypeError when either dimension was left as None.
    resize_img = args.image_width is not None and args.image_height is not None

    # check if output dir exists
    output_dir = args.output_dir if args.output_dir else 'test-{}'.format(
        utils.get_datetime_string())
    if not os.path.exists(output_dir):
        os.mkdir(output_dir)

    # load model (one output channel per color label; input is RGB)
    model = networks.UNet(args.unet_layers, 3, len(args.color_labels))
    model.load_state_dict(torch.load(args.model))
    model = model.eval()
    if not args.cpu:
        model.cuda()

    # iterate all images one by one
    transform = torchvision.transforms.ToTensor()
    for filename in os.listdir(args.dataroot):
        filepath = os.sep.join([args.dataroot, filename])
        # Fix: images must be opened in binary mode, not text mode.
        with open(filepath, 'rb') as f:
            img = Image.open(f)
            if resize_img:
                img = img.resize((args.image_width, args.image_height))
            img = transform(img)
        img = img.view(1, *img.shape)
        img = Variable(img)
        if not args.cpu:
            img = img.cuda()
        output = model(img)
        _, c, h, w = output.data.shape
        output_numpy = output.data.numpy()[0] if args.cpu \
            else output.data.cpu().numpy()[0]
        # Per-pixel class = argmax over the channel axis.
        output_argmax = numpy.argmax(output_numpy, axis=0)
        out_img = numpy.zeros((h, w, 3), dtype=numpy.uint8)
        for i, color in enumerate(args.color_labels):
            out_img[output_argmax == i] = numpy.array(color, dtype=numpy.uint8)
        out_img = Image.fromarray(out_img)
        # Export next to the input's basename, always as .png.
        seg_filepath = os.sep.join(
            [output_dir, filename[:filename.rfind('.')] + '.png'])
        out_img.save(seg_filepath)
        print('{} is exported!'.format(seg_filepath))
def net_generator():
    """Create a UNet generator on the global ``device`` and its optimizer.

    Returns:
        tuple: ``(network, optimizer)`` — the UNet moved to ``device`` and an
        Adam optimizer over its parameters (lr=0.001).
    """
    network = networks.UNet().to(device)
    adam = torch.optim.Adam(network.parameters(), lr=0.001)
    # Alternative optimizer kept for reference:
    # adam = torch.optim.SGD(network.parameters(), lr=0.01, momentum=0.9)
    # --------------------------------------------------
    print_network(network)
    # --------------------------------------------------
    # Checkpoint restore (disabled):
    # checkpoint = torch.load(
    #     "/mnt/1T-5e7/mycodehtml/prac_data_s/kaggle/tgs-salt-identification-challenge/train/checkpoint.pth.tar")
    # network.load_state_dict(checkpoint['state_dict'])
    # adam.load_state_dict(checkpoint['optimizer'])
    return network, adam
def __init__(self, options):
    """Set up frozen feature extractors, a trainable input image, data
    loaders and TensorBoard writers.

    Note: the only thing registered with the optimizer is a single trainable
    image tensor (``self.parameters_to_train``); all networks stay fixed.

    Args:
        options: parsed run options (log_dir, model_name, no_cuda,
            learning_rate, use_augmentation, base_dir, list_dir, batch_size,
            num_workers, num_epochs, ...).
    """
    self.opt = options
    self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

    self.models = {}
    self.parameters_to_train = []
    self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

    # Initialize the resnet50 and resnet101 models for this run.
    # NOTE(review): num_classes / feature_extract are not defined in this
    # scope — presumably module-level globals; confirm.
    model_50, input_size = self.initialize_model(
        "resnet50", num_classes, feature_extract, use_pretrained=True)
    self.models["resnet50"] = model_50
    self.models["resnet50"].to(self.device)

    model_101, input_size = self.initialize_model(
        "resnet101", num_classes, feature_extract, use_pretrained=True)
    self.models["resnet101"] = model_101
    self.models["resnet101"].to(self.device)

    # self.models["RAN"] = DeepLab_ResNet101_MSC(n_classes=21)
    self.models["RAN"] = RAN(in_channels=2048, out_channels=128)
    self.models["RAN"].to(self.device)

    self.models["unet"] = networks.UNet(n_channels=1, n_classes=4)
    self.models["unet"].to(self.device)

    # Replace the parameter list with one trainable image tensor.
    # NOTE(review): hard-coded device="cuda" here ignores opt.no_cuda,
    # unlike self.device above — confirm this is intended.
    self.parameters_to_train = nn.Parameter(
        rescale_transform(torch.normal(
            mean=0.5, std=1, size=(1, 3, 512, 512), device="cuda")),
        requires_grad=True)
    self.model_optimizer = optim.Adam([self.parameters_to_train],
                                      self.opt.learning_rate)

    self.dataset = datasets.Retouch_dataset
    if self.opt.use_augmentation:
        self.transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            # transforms.RandomRotation(degrees=(-20, 20)),
        ])
    else:
        self.transform = None

    # Unreduced per-pixel cross entropy; reduction is applied downstream.
    self.criterion = nn.CrossEntropyLoss(reduction='none')

    # Fix: the original constructed this identical dataset three times in a
    # row; one construction suffices.
    train_dataset = self.dataset(base_dir=self.opt.base_dir,
                                 list_dir=self.opt.list_dir,
                                 split='train',
                                 is_train=True,
                                 transform=self.transform)
    self.train_loader = DataLoader(train_dataset, self.opt.batch_size, True,
                                   num_workers=self.opt.num_workers,
                                   pin_memory=True, drop_last=True)

    val_dataset = self.dataset(base_dir=self.opt.base_dir,
                               list_dir=self.opt.list_dir,
                               split='val',
                               is_train=False,
                               transform=self.transform)
    self.val_loader = DataLoader(val_dataset, self.opt.batch_size, True,
                                 num_workers=self.opt.num_workers,
                                 pin_memory=True, drop_last=True)
    self.val_iter = iter(self.val_loader)

    num_train_samples = len(train_dataset)
    self.num_total_steps = (num_train_samples // self.opt.batch_size
                            * self.opt.num_epochs)

    # One TensorBoard writer per phase.
    self.writers = {}
    for mode in ["train", "val"]:
        self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))
def __init__(self, options):
    """Build trainer state for UNet segmentation plus RAN-based domain
    adaptation: frozen feature extractors, three optimizer groups, source
    and target AAN loaders, and per-phase TensorBoard writers.

    Args:
        options: parsed run options (log_dir, model_name, learning_rate,
            base_dir, list_dir, batch_size, num_workers, num_epochs, ...).
    """
    self.opt = options
    self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

    self.models = {}
    # Three disjoint parameter groups, each with its own Adam optimizer:
    # main (UNet), F (unet_encoder features), D (RAN discriminator).
    self.parameters_to_train = []
    self.parameters_to_train_F = []
    self.parameters_to_train_D = []

    self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # self.models["encoder"] = networks.ResnetEncoder(
    #     self.opt.num_layers, pretrained=False)
    # self.models["encoder"].to(self.device)
    # self.parameters_to_train += list(self.models["encoder"].parameters())
    #
    # self.models["decoder"] = networks.Decoder(
    #     self.models["encoder"].num_ch_enc)
    # self.models["decoder"].to(self.device)
    # self.parameters_to_train += list(self.models["decoder"].parameters())

    # Initialize the resnet50 and resnet101 model for this run
    # (requires_grad=False — presumably used as frozen feature extractors;
    # initialize_model is defined elsewhere, confirm).
    model_50 = self.initialize_model("resnet50", requires_grad=False)
    self.models["resnet50"] = model_50
    self.models["resnet50"].to(self.device)

    model_101 = self.initialize_model("resnet101", requires_grad=False)
    self.models["resnet101"] = model_101
    self.models["resnet101"].to(self.device)

    # self.models["RAN"] = DeepLab_ResNet101_MSC(n_classes=21)
    self.models["RAN"] = RAN(in_channels=512, out_channels=21)
    self.models["RAN"].to(self.device)

    self.models["unet"] = networks.UNet(n_channels=1, n_classes=4)
    self.models["unet"].to(self.device)
    self.parameters_to_train += list(self.models["unet"].parameters())

    # Optimizers
    self.model_optimizer = optim.Adam(self.parameters_to_train, self.opt.learning_rate)

    # NOTE(review): load_model() runs before the F/D groups are built, so it
    # presumably restores only the models created above — confirm.
    self.load_model()

    model_unet_encoder = self.initialize_model("unet_encoder", requires_grad=False)
    self.models["unet_encoder"] = model_unet_encoder
    self.models["unet_encoder"].to(self.device)

    self.parameters_to_train_F += list(self.models["unet_encoder"].parameters())
    self.parameters_to_train_D += list(self.models["RAN"].parameters())

    self.optimizer_F = optim.Adam(self.parameters_to_train_F, self.opt.learning_rate)
    self.optimizer_D = optim.Adam(self.parameters_to_train_D, self.opt.learning_rate)

    # self.models["unet_down4"] = UNet_Layer(output_layer='down4')
    # self.models["unet_down4"].to(self.device)
    # self.parameters_to_train += list(self.models["unet"].parameters())
    # self.parameters_to_train += list(self.models["resnet50"].parameters())
    # self.parameters_to_train += list(self.models["resnet101"].parameters())

    '''
    w = Variable(torch.randn(3, 5), requires_grad=True)
    b = Variable(torch.randn(3, 5), requires_grad=True)
    self.parameters_to_train += w
    self.parameters_to_train += b
    '''
    '''
    self.model_optimizer = optim.Adam(self.parameters_to_train, self.opt.learning_rate)
    '''

    self.dataset = datasets.Retouch_dataset
    self.coco_dataset = datasets.Coco_dataset

    # Augmentation is currently disabled; the alternatives below are kept
    # for reference.
    self.transform = None
    '''
    self.transform = transforms.Compose([
        transforms.Normalize(mean=[0.1422], std=[0.0885])
    ])
    '''
    '''
    if self.opt.use_augmentation:
        self.transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),
                                             transforms.RandomVerticalFlip(p=0.5),
                                             #transforms.RandomRotation(degrees=(-20, 20)),
                                             ])
    else:
        self.transform = None
    '''

    # self.criterion = nn.CrossEntropyLoss(weight=torch.FloatTensor(self.opt.ce_weighting).to(self.device),
    #                                      ignore_index=self.opt.ignore_idx)
    # Unreduced per-pixel cross entropy; reduction happens downstream.
    self.criterion = nn.CrossEntropyLoss(reduction='none')

    # Source/target domain data for the appearance adaptation network (AAN);
    # batch size 1, shuffled, partial last batch dropped.
    self.source_dataset_AAN, self.source_dir_AAN = self.initialize_dataset_AAN("cirrus_val")
    self.target_dataset_AAN, self.target_dir_AAN = self.initialize_dataset_AAN("spectralis")
    self.source_dataloader_AAN = DataLoader(
        self.source_dataset_AAN, 1, True,
        num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
    self.target_dataloader_AAN = DataLoader(
        self.target_dataset_AAN, 1, True,
        num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)

    train_dataset = self.dataset(
        base_dir=self.opt.base_dir,
        list_dir=self.opt.list_dir,
        split='train',
        is_train=True,
        transform=self.transform)
    self.train_loader = DataLoader(
        train_dataset, self.opt.batch_size, True,
        num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)

    val_dataset = self.dataset(
        base_dir=self.opt.base_dir,
        list_dir=self.opt.list_dir,
        split='val',
        is_train=False,
        transform=self.transform)
    # NOTE(review): validation loader is shuffled (third positional arg True)
    # and drops the last partial batch — unusual for validation; confirm.
    self.val_loader = DataLoader(
        val_dataset, self.opt.batch_size, True,
        num_workers=self.opt.num_workers, pin_memory=True, drop_last=True)
    self.val_iter = iter(self.val_loader)

    num_train_samples = len(train_dataset)
    self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs

    # One TensorBoard writer per phase.
    self.writers = {}
    for mode in ["train", "val", "AAN", "RAN"]:
        self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode))
def build_model_graph(self):
    """Assemble the full TF1 training graph: placeholders, TensorBoard image
    summaries, CPU-pinned input pipeline, the core network (UNet/VNet),
    softmax/argmax outputs, a dice-style loss, and per-class metrics.

    Reads configuration from instance attributes (dimension, patch_shape,
    input/output_channel_num, batch_size, network_name, loss_name, ...).
    """
    print("{}: Start to build model graph...".format(datetime.datetime.now()))

    self.global_step_op = tf.train.get_or_create_global_step()

    # Placeholder shapes: NHWC for 2D, NDHWC for 3D; labels carry one channel.
    if self.dimension == 2:
        input_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], self.input_channel_num)
        output_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], 1)
    elif self.dimension == 3:
        input_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], self.patch_shape[2], self.input_channel_num)
        output_batch_shape = (None, self.patch_shape[0], self.patch_shape[1], self.patch_shape[2], 1)
    else:
        sys.exit('Invalid Patch Shape (length should be 2 or 3)')

    self.images_placeholder, self.labels_placeholder = self.placeholder_inputs(input_batch_shape, output_batch_shape)

    # plot input and output images to tensorboard
    if self.image_log:
        if self.dimension == 2:
            for image_channel in range(self.input_channel_num):
                image_log = tf.cast(self.images_placeholder[:, :, :, image_channel:image_channel + 1], dtype=tf.uint8)
                tf.summary.image(self.image_filenames[image_channel], image_log, max_outputs=self.batch_size)
            # Scale integer labels so the classes spread over [0, 255];
            # the divisor excludes class 0 when it is a declared class.
            if 0 in self.label_classes:
                labels_log = tf.cast(self.labels_placeholder * math.floor(255 / (self.output_channel_num - 1)), dtype=tf.uint8)
            else:
                labels_log = tf.cast(self.labels_placeholder * math.floor(255 / self.output_channel_num), dtype=tf.uint8)
            tf.summary.image("label", labels_log, max_outputs=self.batch_size)
        else:
            # 3D: log each batch item as a stack of slices (depth becomes
            # the outputs axis via the [3,1,2,0] transpose).
            for batch in range(self.batch_size):
                for image_channel in range(self.input_channel_num):
                    image_log = tf.cast(self.images_placeholder[batch:batch + 1, :, :, :, image_channel], dtype=tf.uint8)
                    tf.summary.image(self.image_filenames[image_channel], tf.transpose(image_log, [3, 1, 2, 0]), max_outputs=self.patch_shape[-1])
                if 0 in self.label_classes:
                    labels_log = tf.cast(self.labels_placeholder[batch:batch + 1, :, :, :, 0] * math.floor(255 / (self.output_channel_num - 1)), dtype=tf.uint8)
                else:
                    labels_log = tf.cast(self.labels_placeholder[batch:batch + 1, :, :, :, 0] * math.floor(255 / self.output_channel_num), dtype=tf.uint8)
                tf.summary.image("label", tf.transpose(labels_log, [3, 1, 2, 0]), max_outputs=self.patch_shape[-1])

    # Get images and labels
    # create transformations to image and labels
    # Force input pipepline to CPU:0 to avoid operations sometimes ended up at GPU and resulting a slow down
    with tf.device('/cpu:0'):
        if self.dimension == 2:
            train_transforms_3d = []
            train_transforms_2d = [
                NiftiDataset2D.ManualNormalization(0, 300),
                NiftiDataset2D.Resample(self.spacing),
                NiftiDataset2D.Padding(self.patch_shape),
                NiftiDataset2D.RandomCrop(self.patch_shape)
            ]
            test_transforms_3d = []
            # NOTE(review): the 2D test transforms are identical to the
            # train transforms, including RandomCrop — confirm intended.
            test_transforms_2d = [
                NiftiDataset2D.ManualNormalization(0, 300),
                NiftiDataset2D.Resample(self.spacing),
                NiftiDataset2D.Padding(self.patch_shape),
                NiftiDataset2D.RandomCrop(self.patch_shape)
            ]
            trainTransforms = {"3D": train_transforms_3d, "2D": train_transforms_2d}
            testTransforms = {"3D": test_transforms_3d, "2D": test_transforms_2d}
        else:
            trainTransforms = [
                # NiftiDataset.Normalization(),
                # NiftiDataset3D.ExtremumNormalization(0.1),
                # NiftiDataset3D.ManualNormalization(0,300),
                NiftiDataset3D.StatisticalNormalization(2.5),
                NiftiDataset3D.Resample((self.spacing[0], self.spacing[1], self.spacing[2])),
                NiftiDataset3D.Padding((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2])),
                NiftiDataset3D.RandomCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]), self.drop_ratio, self.min_pixel),
                # NiftiDataset.ConfidenceCrop((FLAGS.patch_size*3, FLAGS.patch_size*3, FLAGS.patch_layer*3),(0.0001,0.0001,0.0001)),
                # NiftiDataset.BSplineDeformation(randomness=2),
                # NiftiDataset.ConfidenceCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),(0.5,0.5,0.5)),
                # NiftiDataset3D.ConfidenceCrop2((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),rand_range=32,probability=0.8),
                # NiftiDataset3D.RandomFlip([True, False, False]),
                NiftiDataset3D.RandomNoise()
            ]

            # use random crop for testing
            testTransforms = [
                # NiftiDataset.Normalization(),
                # NiftiDataset3D.ExtremumNormalization(0.1),
                # NiftiDataset3D.ManualNormalization(0,300),
                NiftiDataset3D.StatisticalNormalization(2.5),
                NiftiDataset3D.Resample((self.spacing[0], self.spacing[1], self.spacing[2])),
                NiftiDataset3D.Padding((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2])),
                NiftiDataset3D.RandomCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]), self.drop_ratio, self.min_pixel)
                # NiftiDataset.ConfidenceCrop((FLAGS.patch_size*2, FLAGS.patch_size*2, FLAGS.patch_layer*2),(0.0001,0.0001,0.0001)),
                # NiftiDataset.BSplineDeformation(),
                # NiftiDataset.ConfidenceCrop((self.patch_shape[0], self.patch_shape[1], self.patch_shape[2]),(0.75,0.75,0.75)),
                # NiftiDataset.ConfidenceCrop2((FLAGS.patch_size, FLAGS.patch_size, FLAGS.patch_layer),rand_range=32,probability=0.8),
                # NiftiDataset.RandomFlip([True, False, False]),
            ]

        # get input and output datasets
        self.train_iterator = self.dataset_iterator(self.train_data_dir, trainTransforms)
        self.next_element_train = self.train_iterator.get_next()

        if self.testing:
            self.test_iterator = self.dataset_iterator(self.test_data_dir, testTransforms)
            self.next_element_test = self.test_iterator.get_next()

    print("{}: Dataset pipeline complete".format(datetime.datetime.now()))

    # network models:
    if self.network_name == "FCN":
        sys.exit("Network to be developed")
    elif self.network_name == "UNet":
        # NOTE(review): UNet hard-codes dropout_rate=0.01 while VNet uses
        # self.dropout_rate — confirm this asymmetry is intended.
        self.network = networks.UNet(
            num_output_channels=self.output_channel_num,
            dropout_rate=0.01,
            num_channels=4,
            num_levels=4,
            num_convolutions=2,
            bottom_convolutions=2,
            is_training=True,
            activation_fn="relu"
        )
    elif self.network_name == "VNet":
        self.network = networks.VNet(
            num_classes=self.output_channel_num,
            dropout_rate=self.dropout_rate,
            num_channels=16,
            num_levels=4,
            num_convolutions=(1, 2, 3, 3),
            bottom_convolutions=3,
            is_training=True,
            activation_fn="prelu"
        )
    else:
        sys.exit("Invalid Network")

    print("{}: Core network complete".format(datetime.datetime.now()))

    self.logits = self.network.GetNetwork(self.images_placeholder)

    # softmax op
    self.softmax_op = tf.nn.softmax(self.logits, name="softmax")

    if self.image_log:
        if self.dimension == 2:
            for output_channel in range(self.output_channel_num):
                # softmax_log = grayscale_to_rainbow(self.softmax_op[:,:,:,output_channel:output_channel+1])
                softmax_log = self.softmax_op[:, :, :, output_channel:output_channel + 1]
                softmax_log = tf.cast(softmax_log * 255, dtype=tf.uint8)
                tf.summary.image("softmax_" + str(self.label_classes[output_channel]), softmax_log, max_outputs=self.batch_size)
        else:
            for batch in range(self.batch_size):
                for output_channel in range(self.output_channel_num):
                    softmax_log = grayscale_to_rainbow(tf.transpose(self.softmax_op[batch:batch + 1, :, :, :, output_channel], [3, 1, 2, 0]))
                    softmax_log = tf.cast(softmax_log * 255, dtype=tf.uint8)
                    tf.summary.image("softmax_" + str(self.label_classes[output_channel]), softmax_log, max_outputs=self.patch_shape[-1])

    print("{}: Output layers complete".format(datetime.datetime.now()))

    # loss function
    with tf.name_scope("loss"):
        # """
        # Tricks for faster converge: Here we provide two calculation methods, first one will ignore to classical dice formula
        # method 1: exclude the 0-th label in dice calculation. to use this method properly, you must set 0 as the first value in SegmentationClasses in config.json
        # method 2: dice will be average on all classes
        # """
        if self.dimension == 2:
            labels = tf.one_hot(self.labels_placeholder[:, :, :, 0], depth=self.output_channel_num)
        else:
            labels = tf.one_hot(self.labels_placeholder[:, :, :, :, 0], depth=self.output_channel_num)

        # if 0 in self.label_classes:
        #     ################### method 1 ###################
        #     if self.dimension ==2:
        #         labels = labels[:,:,:,1:]
        #         softmax = self.softmax_op[:,:,:,1:]
        #     else:
        #         labels = labels[:,:,:,:,1:]
        #         softmax = self.softmax_op[:,:,:,:,1:]
        # else:
        #     ################### method 2 ###################
        #     labels = labels
        #     softmax = self.softmax_op

        # Method 2 is active: all classes participate in the dice average
        # (the self-assignment is leftover scaffolding from the branch above).
        labels = labels
        softmax = self.softmax_op

        # For 2D, dice is computed per-image over the spatial axes (1,2);
        # for 3D, dice_coe's default axes are used.
        if (self.loss_name == "sorensen"):
            if self.dimension == 2:
                sorensen = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='sorensen', axis=(1, 2))
            else:
                sorensen = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='sorensen')
            self.loss_op = 1. - sorensen
        elif (self.loss_name == "weighted_sorensen"):
            if self.dimension == 2:
                sorensen = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='sorensen', axis=(1, 2), weighted=True)
            else:
                sorensen = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='sorensen', weighted=True)
            self.loss_op = 1. - sorensen
        elif (self.loss_name == "jaccard"):
            if self.dimension == 2:
                jaccard = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='jaccard', axis=(1, 2))
            else:
                jaccard = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='jaccard')
            self.loss_op = 1. - jaccard
        # NOTE(review): "weightd_jaccard" looks like a typo for
        # "weighted_jaccard", but config files may rely on the misspelled
        # key — do not change without checking configs.
        elif (self.loss_name == "weightd_jaccard"):
            if self.dimension == 2:
                jaccard = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='jaccard', axis=(1, 2), weighted=True)
            else:
                jaccard = dice_coe(softmax, tf.cast(labels, dtype=tf.float32), loss_type='jaccard', weighted=True)
            self.loss_op = 1. - jaccard
        else:
            sys.exit("Invalid loss function")

    tf.summary.scalar('loss', self.loss_op)

    print("{}: Loss function complete".format(datetime.datetime.now()))

    # argmax op
    with tf.name_scope("predicted_label"):
        self.pred_op = tf.argmax(self.logits, axis=-1, name="prediction")

    if self.image_log:
        if self.dimension == 2:
            # Same [0, 255] scaling scheme as the label summaries above.
            if 0 in self.label_classes:
                pred_log = tf.cast(self.pred_op * math.floor(255 / (self.output_channel_num - 1)), dtype=tf.uint8)
            else:
                pred_log = tf.cast(self.pred_op * math.floor(255 / self.output_channel_num), dtype=tf.uint8)
            pred_log = tf.expand_dims(pred_log, axis=-1)
            tf.summary.image("pred", pred_log, max_outputs=self.batch_size)
        else:
            for batch in range(self.batch_size):
                if 0 in self.label_classes:
                    pred_log = tf.cast(self.pred_op[batch:batch + 1, :, :, :] * math.floor(255 / (self.output_channel_num - 1)), dtype=tf.uint8)
                else:
                    pred_log = tf.cast(self.pred_op[batch:batch + 1, :, :, :] * math.floor(255 / (self.output_channel_num)), dtype=tf.uint8)
                tf.summary.image("pred", tf.transpose(pred_log, [3, 1, 2, 0]), max_outputs=self.patch_shape[-1])

    # accuracy of the model
    with tf.name_scope("metrics"):
        correct_pred = tf.equal(tf.expand_dims(self.pred_op, -1), tf.cast(self.labels_placeholder, dtype=tf.int64))
        accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
        tf.summary.scalar('accuracy', accuracy)

        # confusion matrix
        if self.dimension == 2:
            label_one_hot = tf.one_hot(self.labels_placeholder[:, :, :, 0], depth=self.output_channel_num)
            pred_one_hot = tf.one_hot(self.pred_op, depth=self.output_channel_num)
        else:
            label_one_hot = tf.one_hot(self.labels_placeholder[:, :, :, :, 0], depth=self.output_channel_num)
            pred_one_hot = tf.one_hot(self.pred_op[:, :, :, :], depth=self.output_channel_num)

        # Per-class sensitivity/specificity/dice; class 0 (background when
        # present) is skipped.
        for i in range(self.output_channel_num):
            if i == 0:
                continue
            else:
                if self.dimension == 2:
                    tp, tp_op = tf.metrics.true_positives(label_one_hot[:, :, :, i], pred_one_hot[:, :, :, i], name="true_positives_" + str(self.label_classes[i]))
                    tn, tn_op = tf.metrics.true_negatives(label_one_hot[:, :, :, i], pred_one_hot[:, :, :, i], name="true_negatives_" + str(self.label_classes[i]))
                    fp, fp_op = tf.metrics.false_positives(label_one_hot[:, :, :, i], pred_one_hot[:, :, :, i], name="false_positives_" + str(self.label_classes[i]))
                    fn, fn_op = tf.metrics.false_negatives(label_one_hot[:, :, :, i], pred_one_hot[:, :, :, i], name="false_negatives_" + str(self.label_classes[i]))
                else:
                    tp, tp_op = tf.metrics.true_positives(label_one_hot[:, :, :, :, i], pred_one_hot[:, :, :, :, i], name="true_positives_" + str(self.label_classes[i]))
                    tn, tn_op = tf.metrics.true_negatives(label_one_hot[:, :, :, :, i], pred_one_hot[:, :, :, :, i], name="true_negatives_" + str(self.label_classes[i]))
                    fp, fp_op = tf.metrics.false_positives(label_one_hot[:, :, :, :, i], pred_one_hot[:, :, :, :, i], name="false_positives_" + str(self.label_classes[i]))
                    fn, fn_op = tf.metrics.false_negatives(label_one_hot[:, :, :, :, i], pred_one_hot[:, :, :, :, i], name="false_negatives_" + str(self.label_classes[i]))
                sensitivity_op = tf.divide(tf.cast(tp_op, tf.float32), tf.cast(tf.add(tp_op, fn_op), tf.float32))
                specificity_op = tf.divide(tf.cast(tn_op, tf.float32), tf.cast(tf.add(tn_op, fp_op), tf.float32))
                dice_op = 2. * tp_op / (2. * tp_op + fp_op + fn_op)

            tf.summary.scalar('sensitivity_' + str(self.label_classes[i]), sensitivity_op)
            tf.summary.scalar('specificity_' + str(self.label_classes[i]), specificity_op)
            tf.summary.scalar('dice_' + str(self.label_classes[i]), dice_op)

    print("{}: Metrics complete".format(datetime.datetime.now()))

    print("{}: Build graph complete".format(datetime.datetime.now()))
def __init__(self, options):
    """Initialize the UNet trainer: model, optimizer, data loaders,
    loss function and TensorBoard writers.

    Args:
        options: parsed run options (log_dir, model_name, no_cuda,
            learning_rate, use_augmentation, base_dir, list_dir,
            batch_size, num_workers, num_epochs).
    """
    self.opt = options
    self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name)

    self.models = {}
    self.parameters_to_train = []
    self.device = torch.device("cpu" if self.opt.no_cuda else "cuda")

    # Single segmentation network: 1-channel input, 4 output classes.
    unet = networks.UNet(n_channels=1, n_classes=4)
    unet.to(self.device)
    self.models["unet"] = unet
    self.parameters_to_train += list(unet.parameters())

    self.model_optimizer = optim.Adam(self.parameters_to_train,
                                      self.opt.learning_rate)

    self.dataset = datasets.Retouch_dataset

    if self.opt.use_augmentation:
        # Random flips only; rotation stays disabled.
        flips = [transforms.RandomHorizontalFlip(p=0.5),
                 transforms.RandomVerticalFlip(p=0.5)]
        self.transform = transforms.Compose(flips)
    else:
        self.transform = None

    # Unreduced per-pixel cross entropy; reduction happens downstream.
    self.criterion = nn.CrossEntropyLoss(reduction='none')

    loader_kwargs = dict(num_workers=self.opt.num_workers,
                         pin_memory=True,
                         drop_last=True)

    ds_train = self.dataset(base_dir=self.opt.base_dir,
                            list_dir=self.opt.list_dir,
                            split='train',
                            is_train=True,
                            transform=self.transform)
    self.train_loader = DataLoader(ds_train, self.opt.batch_size, True,
                                   **loader_kwargs)

    ds_val = self.dataset(base_dir=self.opt.base_dir,
                          list_dir=self.opt.list_dir,
                          split='val',
                          is_train=False,
                          transform=self.transform)
    self.val_loader = DataLoader(ds_val, self.opt.batch_size, True,
                                 **loader_kwargs)
    self.val_iter = iter(self.val_loader)

    num_train_samples = len(ds_train)
    self.num_total_steps = (num_train_samples // self.opt.batch_size
                            * self.opt.num_epochs)

    # One TensorBoard writer per phase.
    self.writers = {mode: SummaryWriter(os.path.join(self.log_path, mode))
                    for mode in ("train", "val")}
# GPU enabled cuda = torch.cuda.is_available() # cross-entropy loss: weighting of negative vs positive pixels and NLL loss layer loss_weight = torch.FloatTensor([0.01, 0.99]) if cuda: # Obtaining log-probabilities in a neural network is easily achieved by adding a LOgSoftmax layer in the last layer # of your netork. You may use CrossEntropyLoss instead, if you prefer not to add an extra layer. loss_weight = loss_weight.cuda() criterion = nn.NLLLoss(weight=loss_weight) # network and optimizer # net = networks.VNet_Xtra(dice=dice, dropout=dropout, context=context) net = networks.UNet() if cuda: net = torch.nn.DataParallel(net, device_ids=list( range(torch.cuda.device_count()))).cuda() optimizer = optim.Adam(net.parameters(), lr=lr) # train data loader # train = LiverDataSet(directory=train_folder, augment=augment, context=context) train = LiverDataSet(directory=train_folder, augment=augment) train_sampler = torch.utils.data.sampler.WeightedRandomSampler( weights=train.getWeights(), num_samples=num_samples) # train_data = torch.utils.data.DataLoader(train, batch_size=batch_size, shuffle=True, sampler=train_sampler, # num_workers=2) train_data = torch.utils.data.DataLoader(train,
def train(args):
    """Train a UNet for segmentation (or per-pixel regression), logging
    training/validation metrics and saving periodic checkpoints.

    Args:
        args: parsed CLI namespace — data layout (dataroot, img_dir,
            seg_dir, color_labels, image sizes, augmentation flags),
            model options (unet_layers, batch_norm, regression, cpu) and
            schedule (epochs, lr_policy, optimizer, batch sizes, intervals).
    """
    # set logger: file log inside the run directory plus console echo
    logging_dir = args.output_dir if args.output_dir else 'train-{}'.format(
        utils.get_datetime_string())
    os.mkdir('{}'.format(logging_dir))
    logging.basicConfig(level=logging.INFO,
                        filename='{}/log.txt'.format(logging_dir),
                        format='%(asctime)s %(message)s',
                        filemode='w')
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)
    # Fix: "Taks" -> "Task" typo in the banner message.
    logging.info('=========== Task {} started! ==========='.format(
        args.output_dir))
    for arg in vars(args):
        logging.info('{}: {}'.format(arg, getattr(args, arg)))
    logging.info('========================================')

    # initialize loaders; augmentation only on the training split
    train_set = utils.SegmentationImageFolder(
        os.sep.join([args.dataroot, 'train']),
        image_folder=args.img_dir,
        segmentation_folder=args.seg_dir,
        labels=args.color_labels,
        image_size=(args.image_width, args.image_height),
        random_horizontal_flip=args.random_horizontal_flip,
        random_rotation=args.random_rotation,
        random_crop=args.random_crop,
        random_square_crop=args.random_square_crop,
        label_regr=args.regression)
    val_set = utils.SegmentationImageFolder(
        os.sep.join([args.dataroot, 'val']),
        image_folder=args.img_dir,
        segmentation_folder=args.seg_dir,
        labels=args.color_labels,
        image_size=(args.image_width, args.image_height),
        random_square_crop=args.random_square_crop,
        label_regr=args.regression)
    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=args.batch_size,
                                               shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=args.val_batch_size)

    # initialize model, input channels need to be calculated by hand
    n_classes = len(args.color_labels)
    if args.regression:
        model = networks.UNet(args.unet_layers, 3, 1, use_bn=args.batch_norm)
    else:
        model = networks.UNet(args.unet_layers, 3, n_classes,
                              use_bn=args.batch_norm)
    if not args.cpu:
        model.cuda()
    criterion = nn.MSELoss() if args.regression else utils.CrossEntropyLoss2D()

    # train
    iterations = 0
    for epoch in range(args.epochs):
        model.train()
        # update lr according to lr policy
        # NOTE(review): assumes epoch 0 is a key of lr_policy; otherwise
        # `optimizer` below is referenced before assignment — confirm.
        if epoch in args.lr_policy:
            lr = args.lr_policy[epoch]
            optimizer = utils.get_optimizer(args.optimizer,
                                            model.parameters(),
                                            lr=lr,
                                            momentum=args.momentum,
                                            nesterov=args.nesterov)
            if epoch > 0:
                logging.info(
                    '| Learning Rate | Epoch: {: >3d} | Change learning rate to {}'
                    .format(epoch + 1, lr))
            else:
                logging.info(
                    '| Learning Rate | Initial learning rate: {}'.format(lr))

        # iterate all samples
        losses = utils.AverageMeter()
        for i_batch, (img, seg) in enumerate(train_loader):
            img = Variable(img)
            seg = Variable(seg)
            if not args.cpu:
                img = img.cuda()
                seg = seg.cuda()

            # compute output
            output = model(img)
            loss = criterion(output, seg)
            losses.update(loss.data[0])

            # compute gradient and do SGD step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # logging training curve
            if iterations % args.print_interval == 0:
                logging.info('| Iterations: {: >6d} '
                             '| Epoch: {: >3d}/{: >3d} '
                             '| Batch: {: >4d}/{: >4d} '
                             '| Training loss: {:.6f}'.format(
                                 iterations, epoch + 1, args.epochs, i_batch,
                                 len(train_loader) - 1, losses.avg))
                losses = utils.AverageMeter()

            # validation on all val samples
            if iterations % args.validation_interval == 0:
                model.eval()
                val_losses = utils.AverageMeter()
                gt_pixel_count = [0] * n_classes
                pred_pixel_count = [0] * n_classes
                intersection_pixel_count = [0] * n_classes
                union_pixel_count = [0] * n_classes
                for img, seg in val_loader:
                    img = Variable(img)
                    seg = Variable(seg)
                    if not args.cpu:
                        img = img.cuda()
                        seg = seg.cuda()

                    # compute output
                    output = model(img)
                    loss = criterion(output, seg)
                    val_losses.update(
                        loss.data[0],
                        float(img.size(0)) / float(args.batch_size))
                    output_numpy = output.data.numpy() if args.cpu \
                        else output.data.cpu().numpy()
                    pred_labels = numpy.argmax(output_numpy, axis=1)
                    gt_labels = seg.data.numpy() if args.cpu \
                        else seg.data.cpu().numpy()
                    pred_labels = pred_labels.flatten()
                    gt_labels = gt_labels.flatten()
                    for i in range(n_classes):
                        pred_pixel_count[i] += (pred_labels == i).sum()
                        gt_pixel_count[i] += (gt_labels == i).sum()
                        # Fix: numpy.int was removed in NumPy 1.24; use the
                        # equivalent explicit dtype.
                        gt_dumb = numpy.full(gt_labels.shape, -1,
                                             dtype=numpy.int64)
                        pred_dumb = numpy.full(pred_labels.shape, -2,
                                               dtype=numpy.int64)
                        # Intersection: positions where both maps label class i.
                        gt_dumb[gt_labels == i] = 0
                        pred_dumb[pred_labels == i] = 0
                        intersection_pixel_count[i] += (
                            gt_dumb == pred_dumb).sum()
                        # Union: positions labeled class i by either map.
                        pred_dumb[gt_labels == i] = 0
                        union_pixel_count[i] += (pred_dumb == 0).sum()

                # calculate mPA & mIOU
                mPA = 0
                mIOU = 0
                for i in range(n_classes):
                    # Fix: guard against ZeroDivisionError when a class is
                    # entirely absent from the validation set (its
                    # intersection is then 0 as well, so it contributes 0).
                    if gt_pixel_count[i] > 0:
                        mPA += float(intersection_pixel_count[i]) / float(
                            gt_pixel_count[i])
                    if union_pixel_count[i] > 0:
                        mIOU += float(intersection_pixel_count[i]) / float(
                            union_pixel_count[i])
                mPA /= float(n_classes)
                mIOU /= float(n_classes)
                logging.info('| Iterations: {: >6d} '
                             '| Epoch: {: >3d}/{: >3d} '
                             '| Average mPA: {:.4f} '
                             '| Average mIOU: {:.4f} '
                             '| Validation loss: {:.6f} '.format(
                                 iterations, epoch + 1, args.epochs, mPA,
                                 mIOU, val_losses.avg))
                model.train()

            if iterations % args.checkpoint_interval == 0 and iterations > 0:
                model_weights_path = \
                    '{}/iterations-{:0>6d}-epoch-{:0>3d}.pth'.format(
                        logging_dir, iterations, epoch + 1)
                torch.save(model.state_dict(), model_weights_path)
                logging.info(
                    '| Checkpoint | {} is saved!'.format(model_weights_path))

            iterations += 1