def main():
    """Train a new UNet or restore saved weights, chosen by ``sys.argv[1]``.

    When the first command-line argument is ``'train'`` the model is trained
    through ``train``; any other value loads the state dict stored at
    ``PATH`` and prints the model after switching it to eval mode.

    NOTE(review): ``use_gpu``, ``PATH``, ``x_train``, ``y_train``, ``x_val``,
    ``y_val``, ``width_out`` and ``height_out`` are not assigned inside this
    function (the local assignments were commented out), so they have to
    exist at module scope -- confirm.

    Raises:
        SystemExit: if no mode argument was passed on the command line.
    """
    if len(sys.argv) < 2:
        # Fail with a usage hint instead of a bare IndexError further down.
        raise SystemExit(
            "usage: pass 'train' to train; any other value loads PATH")
    mode = sys.argv[1]

    batch_size = 3
    epochs = 1
    epoch_lapse = 50
    threshold = 0.5
    learning_rate = 0.01

    unet = UNet(in_channel=1, out_channel=2)
    if use_gpu:
        unet = unet.cuda()
    criterion = torch.nn.CrossEntropyLoss()
    # Use the same learning_rate that is handed to train() so the two values
    # cannot drift apart when one of them is edited.
    optimizer = torch.optim.SGD(unet.parameters(), lr=learning_rate,
                                momentum=0.99)

    if mode == 'train':
        train(unet, batch_size, epochs, epoch_lapse, threshold,
              learning_rate, criterion, optimizer, x_train, y_train, x_val,
              y_val, width_out, height_out)
    else:
        if use_gpu:
            unet.load_state_dict(torch.load(PATH))
        else:
            # Remap storages so the checkpoint loads without CUDA.
            unet.load_state_dict(torch.load(PATH, map_location='cpu'))
        print(unet.eval())
def load_model(data, model_path, cuda=True):
    """Restore a ``UNet`` from ``model_path`` and run it once on ``data``.

    Args:
        data: input tensor; a leading dimension is added with
            ``torch.unsqueeze(data, 0)`` before the forward pass.
        model_path: path of the saved state dict.
        cuda: run on the GPU when true, otherwise on the CPU.

    Returns:
        The network output for ``data``.

    Raises:
        Exception: if ``cuda`` is requested but CUDA is unavailable.
    """
    if cuda and not torch.cuda.is_available():
        raise Exception("No GPU found, please run without --cuda")

    unet = UNet()
    if cuda:
        unet = unet.cuda()
        weights = torch.load(model_path)
    else:
        # Keep every storage where it is deserialised, i.e. on the CPU.
        weights = torch.load(model_path,
                             map_location=lambda storage, loc: storage)
    unet.load_state_dict(weights)

    # NOTE(review): the network is never switched to eval mode here --
    # confirm that is intended for inference.
    batch = Variable(data.cuda() if cuda else data)
    batch = torch.unsqueeze(batch, 0)
    result = unet(batch)
    return result.cuda() if cuda else result
def define_G(input_nc, output_nc, ngf, norm='batch', use_dropout=False,
             gpu_ids=None):
    """Build the generator network (a single-class ``UNet``) on the GPU.

    Args:
        input_nc: unused by the current UNet generator; kept for callers.
        output_nc: unused by the current UNet generator; kept for callers.
        ngf: unused by the current UNet generator; kept for callers.
        norm: normalisation name forwarded to ``get_norm_layer``.
        use_dropout: unused by the current UNet generator; kept for callers.
        gpu_ids: optional list of CUDA device ids. The first id selects the
            device; with no ids the current CUDA device is used.

    Returns:
        The initialised generator, with ``weights_init`` applied.
    """
    # A list literal as default would be shared between calls; use None.
    gpu_ids = [] if gpu_ids is None else gpu_ids
    use_gpu = len(gpu_ids) > 0

    # The returned layer is not consumed by UNet; the call is kept so any
    # checking get_norm_layer does on `norm` still happens (presumably it
    # rejects unknown names -- verify).
    get_norm_layer(norm_type=norm)

    if use_gpu:
        assert torch.cuda.is_available(), "gpu_ids given but CUDA is unavailable"

    netG = UNet(n_classes=1)
    if use_gpu:
        # Go straight to the requested device instead of first allocating on
        # the current device and then moving again.
        netG.cuda(gpu_ids[0])
    else:
        # As before, the generator always ends up on a CUDA device.
        netG.cuda()
    netG.apply(weights_init)
    return netG
def main():
    """Run the train, val or test phase selected by ``args.phase``.

    A ``UNet(3, 1)`` is restored from ``opt.ckpt_path`` and moved, together
    with the soft-dice loss, to CUDA device 0. The train phase lowers the
    learning rate by ``opt.lr_decay`` whenever an epoch's loss is higher than
    the previous epoch's loss.
    """
    global args

    net = UNet(3, 1)
    net.load(opt.ckpt_path)
    loss = Loss('soft_dice_loss')
    torch.cuda.set_device(0)
    net = net.cuda()
    loss = loss.cuda()

    def build_loader(phase, batch_size):
        # Every phase uses a shuffled loader over the matching dataset split.
        dataset = NucleiDetector(opt, phase=phase)
        return DataLoader(dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          num_workers=opt.num_workers,
                          pin_memory=opt.pin_memory)

    phase = args.phase
    if phase == 'val':
        val(build_loader('val', opt.batch_size), net, loss)
        return
    if phase != 'train':
        # Anything that is neither train nor val is treated as test.
        test(build_loader('test', 1), net, opt)
        return

    train_loader = build_loader(phase, opt.batch_size)
    lr = opt.lr
    optimizer = torch.optim.Adam(net.parameters(),
                                 lr=lr,
                                 weight_decay=opt.weight_decay)
    previous_loss = None  # no epoch has run yet
    for epoch in range(opt.epoch + 1):
        now_loss = train(train_loader, net, loss, epoch, optimizer,
                         opt.model_save_freq, opt.model_save_path)
        if previous_loss is not None and now_loss > previous_loss:
            lr *= opt.lr_decay
            for group in optimizer.param_groups:
                group['lr'] = lr
            save_lr(net.model_name, opt.lr_save_path, lr)
        previous_loss = now_loss
def test(args):
    """Segment one image with a trained UNet and save a colour-coded mask.

    Args:
        args: namespace providing ``test_image``, ``cuda``,
            ``model_state_dict``, ``mask_json_path`` and ``test_save_path``.

    The prediction is thresholded at 0.5 per channel. Where several channels
    fire on the same pixel, the colour of the highest channel index wins.
    """
    image = load_test_image(args.test_image)  # 1 c w h
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()
        image = image.cuda()

    print('Loading model param from {}'.format(args.model_state_dict))
    # On a CPU run, remap storages so a checkpoint saved from GPU still loads.
    map_location = None if args.cuda else 'cpu'
    net.load_state_dict(
        torch.load(args.model_state_dict, map_location=map_location))
    net.eval()

    print('Predicting for {}...'.format(args.test_image))
    with torch.no_grad():  # inference only; no autograd graph needed
        ys_pred = net(image)  # 1 ch w h

    with open(args.mask_json_path, 'r', encoding='utf-8') as mask:
        print('Reading mask colors list from {}'.format(args.mask_json_path))
        colors = json.loads(mask.read())
    colors = [tuple(c) for c in colors]
    print('Mask colors: {}'.format(colors))

    ys_pred = ys_pred.cpu().detach().numpy()[0]
    ys_pred[ys_pred < 0.5] = 0
    ys_pred[ys_pred >= 0.5] = 1
    # np.int was removed from NumPy; the builtin int is the same dtype.
    ys_pred = ys_pred.astype(int)

    n_channels, image_w, image_h = ys_pred.shape
    out_image = np.zeros((image_w, image_h, 3))
    # Paint channel by channel in ascending order so later channels overwrite
    # earlier ones, as the former per-pixel loop did.
    for ch in range(n_channels):
        hit = ys_pred[ch] == 1
        if hit.any():
            out_image[hit] = colors[ch][:3]

    out_image = out_image.astype(np.uint8)  # w h c
    out_image = out_image.transpose((1, 0, 2))  # h w c
    out_image = Image.fromarray(out_image)
    out_image.save(args.test_save_path)
    print('Segmentation result has been saved to {}'.format(
        args.test_save_path))
batch_size = 4 lr = 0.001 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') dataset = BasicDataset(dir_img, dir_mask) n_val = int(len(dataset) * 0.1) n_train = len(dataset) - n_val train, val = random_split(dataset, [n_train, n_val]) train_loader = DataLoader(train, batch_size=batch_size, shuffle=True) val_loader = DataLoader(val, batch_size=batch_size, shuffle=False) # writer = SummaryWriter(comment=f'LR_{lr}_BS_{batch_size}_SCALE_{1}') net = UNet(n_channels=3, n_classes=classes, bilinear=True) net = net.cuda() optimizer = optim.Adam(net.parameters(), lr=lr, weight_decay=1e-8) scheduler = optim.lr_scheduler.ReduceLROnPlateau( optimizer, 'min' if net.n_classes > 1 else 'max', patience=2) criterion = ImgWtLossSoftNLL(classes, epochs).cuda() criterion_ = nn.CrossEntropyLoss().cuda() for epoch in range(epochs): epoch_loss = 0.0 for batch in train_loader: imgs = batch['image'] # print(imgs.size()) true_masks = batch['mask'] # print(true_masks.size())
train_loader, val_loader = get_train_val_loader( opt.root_dir, batch_size=opt.batch_size, val_ratio=0.15, shuffle=True, num_workers=4, pin_memory=False) optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay) criterion = nn.BCELoss() vis = Visualizer(env=opt.env) if opt.is_cuda: model.cuda() criterion.cuda() if opt.n_gpu > 1: model = nn.DataParallel(model) run(model, train_loader, val_loader, criterion, vis) else: if opt.is_cuda: model.cuda() if opt.n_gpu > 1: model = nn.DataParallel(model) test_loader = get_test_loader(batch_size=20, shuffle=True, num_workers=opt.num_workers, pin_memory=opt.pin_memory) # load the model and run test
def train():
    """Train the model, evaluating on the dev set after every epoch.

    Reads its configuration from the module-level ``args``. With
    ``args.eval`` set it only evaluates the saved model and exits; with
    ``args.load_model`` set it resumes from the saved weights and score
    histories. The weights are saved to ``args.model_dir`` whenever the dev
    F1 improves, and the learning rate is multiplied by ``args.decay`` every
    ``args.decay_period`` epochs.
    """
    os.makedirs('train_model/', exist_ok=True)
    os.makedirs('result/', exist_ok=True)

    train_data, dev_data, word2id, id2word, char2id, opts = load_data(
        vars(args))
    model = UNet(opts)
    if args.use_cuda:
        model = model.cuda()

    def evaluate(batches):
        # Single place for the Evaluate arguments shared by all call sites.
        return model.Evaluate(
            batches,
            args.data_path + 'dev_eval.json',
            answer_file='result/' + args.model_dir.split('/')[-1] +
            '.answers',
            drop_file=args.data_path + 'drop.json',
            dev=args.data_path + 'dev-v2.0.json')

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    if args.eval:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        evaluate(dev_batches)
        # sys.exit is always available; the bare exit() helper comes from
        # the site module and may be missing.
        sys.exit()

    if args.load_model:
        print('load model...')
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        _, F1 = evaluate(dev_batches)
        best_score = F1
        with open(args.model_dir + '_f1_scores.pkl', 'rb') as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + '_em_scores.pkl', 'rb') as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    # Must be a list: a filter() iterator would be exhausted by the optimizer
    # constructor, leaving clip_grad_norm_ below with nothing to clip.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)
    lrate = args.lrate

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)
        total_size = len(train_data) // args.batch_size

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts['grad_clipping'])
            optimizer.step()
            model.reset_parameters()
            if i % 100 == 0:
                print(
                    'Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f'
                    % (epoch, i, total_size, model.train_loss.value, lrate,
                       best_score))
                sys.stdout.flush()

        model.eval()
        exact_match_score, F1 = evaluate(dev_batches)
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + '_f1_scores.pkl', 'wb') as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + '_em_scores.pkl', 'wb') as f:
            pkl.dump(em_scores, f)

        if best_score < F1:
            best_score = F1
            print('saving %s ...' % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)

        if epoch > 0 and epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group['lr'] = lrate
import random
from torch.autograd import Variable
import torch
import SimpleITK as sitk
import nrrd
import numpy as np
from model import UNet
from metrics import dice_score, pixelwise_acc, iou, sen_score, DiceLoss
from loss import CB_loss

torch.manual_seed(10)
# os.environ['OMP_NUM_THREADS']='1'
# os.environ['CUDA_VISIBLE_DEVICES']='1'
unet = UNet(n_channels=1, n_classes=1)
unet = unet.cuda()
print(unet)


def normalize(x):
    """Min-max scale ``x`` into the range [0, 1].

    Args:
        x: array-like of numeric values.

    Returns:
        ``(x - min) / (max - min)``. When every element is equal the range
        is zero, so an all-zero float array of the same shape is returned
        instead of dividing by zero.
    """
    highest = np.max(x)
    lowest = np.min(x)
    span = highest - lowest
    if span == 0:
        return np.zeros_like(x, dtype=float)
    # Previously the scaled array was computed and then discarded, and the
    # untouched input was returned.
    return (x - lowest) / span
# Code testing config # num_points_fetch = 10 # train_num_pts = 5 # n_epochs = 4 train_num_pts = 4800 num_points_fetch = -1 n_epochs = 40 # Pick model train_on_gpu = True if LOSS_NUM in [5, 6]: model = UNet(n_channels=3, n_classes=4, flag=1).float() else: model = UNet(n_channels=3, n_classes=4).float() model = model.cuda() summary(model, (3, 140, 210)) # Print parameter choices print("Learning rate: " + str(LR)) print("Augmentation: " + str(AUG)) print(model_string) # Pick loss if LOSS_NUM == 0: print("Using Binary cross entropy") criterion = nn.BCELoss() elif LOSS_NUM == 1: print("Using Dice loss") criterion = dice_pytorch elif LOSS_NUM == 2:
netG.eval() p = 0 f_path = '/n/holyscratch01/wadduwage_lab/uom_bme/dataset_static_2020/20200105_synthBeads_1/tr_data_1sls/' for line in img_dir: print(line) GT_ = io.imread(f_path + str(line[0:-1]) + '_gt.png') modalities = np.zeros((32,128,128)) for i in range(0,32): modalities[i,:,:] = io.imread(f_path + str(line[0:-1]) +'_'+str(i+1) +'.png') depth = modalities.shape[2] predicted_im = np.zeros((128,128,1)) if np.min(np.array(GT_))==np.max(np.array(GT_)): print('Yes') GT = torch.from_numpy(np.divide(GT_,max_gt)) img = torch.from_numpy(np.divide(modalities,max_im)[None, :, :]).float() netG = netG.cuda() input = img.cuda() out = netG(input) out = out.cpu() out_img = out.data[0] out_img = np.squeeze(out_img) GT = np.squeeze(GT) predict_path= 'Predicted_mse/epoch_' + str(epochs) +'/' if not os.path.exists(predict_path): os.makedirs(predict_path) imsave(predict_path + '/' + str(line[0:-1]) + '_pred.png',out_img) imsave(predict_path + '/' + str(line[0:-1]) + '_gt.png',(GT)) print('mse=',torch.div(avg_mse,p)) print(avg_mse) print(avg_psnr) print(p)
class Trainer():
    """Train, validate and test a ``UNet`` with Dice loss and early stopping.

    Checkpoints are written to ``config.save_model_dir`` as ``model.pth``,
    with a copy named ``best_model.pth`` whenever the validation IOU improves.
    """

    def __init__(self,config,trainLoader,validLoader):
        """Build the network, optimizer, loss and TensorBoard writer.

        Args:
            config: object providing ``save_model_dir``, ``bestModel``,
                ``use_gpu``, ``resume``, ``epochs``, ``train_paitence`` and
                ``tensorboard_path``.
            trainLoader: loader yielding ``(images, targets)`` for training.
            validLoader: loader yielding ``(images, targets)`` for validation.
        """
        self.config = config

        self.trainLoader = trainLoader
        self.validLoader = validLoader

        self.numTrain = len(self.trainLoader.dataset)
        self.numValid = len(self.validLoader.dataset)

        self.saveModelDir = str(self.config.save_model_dir)+"/"

        self.bestModel = config.bestModel
        self.useGpu = self.config.use_gpu

        # Defined up front so every method can rely on these attributes.
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.counter = 0  # epochs since the validation IOU last improved

        self.net = UNet()

        if(self.config.resume == True):
            print("LOADING SAVED MODEL")
            self.loadCheckpoint()
        else:
            print("INTIALIZING NEW MODEL")

        self.net = self.net.to(self.device)

        self.totalEpochs = config.epochs

        self.optimizer = optim.Adam(self.net.parameters(), lr=5e-4)
        self.loss = DiceLoss()

        self.num_params = sum([p.data.nelement() for p in self.net.parameters()])

        self.trainPaitence = config.train_paitence

        if not self.config.resume:
            # self.freezeLayers(6)
            summary(self.net, input_size=(3,256,256))
            print('[*] Number of model parameters: {:,}'.format(self.num_params))

        self.writer = SummaryWriter(self.config.tensorboard_path+"/")

    def train(self):
        """Run the epoch loop with early stopping on validation IOU.

        Training stops once more than ``trainPaitence`` consecutive epochs
        fail to improve on the best validation IOU.

        Returns:
            None.
        """
        bestIOU = 0
        # Reset here as well, so a second call starts with fresh patience.
        self.counter = 0

        print("\n[*] Train on {} sample pairs, validate on {} trials".format(
            self.numTrain, self.numValid))

        for epoch in range(0,self.totalEpochs):
            print('\nEpoch: {}/{}'.format(epoch+1, self.totalEpochs))

            self.trainOneEpoch(epoch)

            validationIOU = self.validationTest(epoch)

            print("VALIDATION IOU: ",validationIOU)

            # check for improvement
            if(validationIOU > bestIOU):
                print("COUNT RESET !!!")
                bestIOU=validationIOU
                self.counter = 0
                self.saveCheckPoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.net.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        'best_valid_acc': bestIOU,
                    },True)
            else:
                self.counter += 1

            if self.counter > self.trainPaitence:
                self.saveCheckPoint(
                    {
                        'epoch': epoch + 1,
                        'model_state': self.net.state_dict(),
                        'optim_state': self.optimizer.state_dict(),
                        'best_valid_acc': validationIOU,
                    },False)
                print("[!] No improvement in a while, stopping training...")
                break

        print("BEST VALIDATION IOU: ",bestIOU)
        return None

    def trainOneEpoch(self,epoch):
        """Run one optimisation pass over ``trainLoader`` and log the means.

        Args:
            epoch: epoch index used as the TensorBoard step.
        """
        self.net.train()
        train_loss = 0
        total_IOU = 0
        num_batches = 0

        for batch_idx, (images,targets) in enumerate(self.trainLoader):
            images = images.to(self.device)
            targets = targets.to(self.device)

            self.optimizer.zero_grad()

            outputMaps = self.net(images)

            loss = self.loss(outputMaps,targets)
            loss.backward()
            self.optimizer.step()

            train_loss += loss.item()
            current_IOU = calc_IOU(outputMaps,targets)
            total_IOU += current_IOU
            num_batches = batch_idx + 1

            del(images)
            del(targets)

            progress_bar(batch_idx, len(self.trainLoader), 'Loss: %.3f | IOU: %.3f'
                % (train_loss/(batch_idx+1), current_IOU))

        # Divide by the batch count. The former `x/batch_idx+1` parsed as
        # `(x/batch_idx)+1` and divided by zero when there was a single batch.
        denom = max(num_batches, 1)
        self.writer.add_scalar('Train/Loss', train_loss/denom, epoch)
        self.writer.add_scalar('Train/IOU', total_IOU/denom, epoch)

    def validationTest(self,epoch):
        """Evaluate on ``validLoader`` without gradients.

        Args:
            epoch: epoch index used as the TensorBoard step.

        Returns:
            Mean IOU over the validation batches.
        """
        self.net.eval()
        validationLoss = []
        total_IOU = []

        with torch.no_grad():
            for batch_idx, (images,targets) in enumerate(self.validLoader):
                images = images.to(self.device)
                targets = targets.to(self.device)

                outputMaps = self.net(images)

                loss = self.loss(outputMaps,targets)
                currentValidationLoss = loss.item()
                validationLoss.append(currentValidationLoss)
                current_IOU = calc_IOU(outputMaps,targets)
                total_IOU.append(current_IOU)

                # progress_bar(batch_idx, len(self.validLoader), 'Loss: %.3f | IOU: %.3f' % (currentValidationLoss), current_IOU)

                del(images)
                del(targets)

        meanIOU = np.mean(total_IOU)
        meanValidationLoss = np.mean(validationLoss)
        self.writer.add_scalar('Validation/Loss', meanValidationLoss, epoch)
        self.writer.add_scalar('Validation/IOU', meanIOU, epoch)

        print("VALIDATION LOSS: ",meanValidationLoss)

        return meanIOU

    def test(self,dataLoader):
        """Evaluate the first batch of ``dataLoader`` and return its arrays.

        The loop breaks after one batch, so only that batch contributes to
        the printed IOU/loss and to the returned lists.

        Args:
            dataLoader: loader yielding ``(images, targets)``.

        Returns:
            Tuple ``(total_input_images, total_outputs_maps)`` of lists of
            NumPy arrays.
        """
        self.net.eval()
        testLoss = []
        total_IOU = []

        total_outputs_maps = []
        total_input_images = []

        with torch.no_grad():
            for batch_idx, (images,targets) in enumerate(dataLoader):
                images = images.to(self.device)
                targets = targets.to(self.device)

                outputMaps = self.net(images)

                loss = self.loss(outputMaps,targets)
                testLoss.append(loss.item())

                current_IOU = calc_IOU(outputMaps,targets)
                total_IOU.append(current_IOU)

                total_outputs_maps.append(outputMaps.cpu().detach().numpy())

                # total_input_images.append(transforms.ToPILImage()(images))
                total_input_images.append(images.cpu().detach().numpy())

                del(images)
                del(targets)
                break

        meanIOU = np.mean(total_IOU)
        meanLoss = np.mean(testLoss)
        print("TEST IOU: ",meanIOU)
        print("TEST LOSS: ",meanLoss)

        return total_input_images,total_outputs_maps

    def saveCheckPoint(self,state,isBest):
        """Write ``state`` to ``model.pth``; also copy to ``best_model.pth`` if best.

        Args:
            state: checkpoint object passed to ``torch.save``.
            isBest: when true, duplicate the file as ``best_model.pth``.
        """
        filename = "model.pth"
        ckpt_path = os.path.join(self.saveModelDir, filename)
        torch.save(state, ckpt_path)

        if isBest:
            filename = "best_model.pth"
            shutil.copyfile(ckpt_path, os.path.join(self.saveModelDir, filename))

    def loadCheckpoint(self):
        """Load ``model_state`` from the saved checkpoint into ``self.net``.

        Picks ``best_model.pth`` when ``bestModel`` is set, else ``model.pth``.
        """
        print("[*] Loading model from {}".format(self.saveModelDir))
        if(self.bestModel):
            print("LOADING BEST MODEL")
            filename = "best_model.pth"
        else:
            filename = "model.pth"

        ckpt_path = os.path.join(self.saveModelDir, filename)
        print(ckpt_path)

        if(self.useGpu==False):
            # Deserialise onto the CPU, then restore the weights. Previously
            # the raw checkpoint object was assigned to self.net here.
            self.ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
            self.net.load_state_dict(self.ckpt['model_state'])
        else:
            print("*"*40+" LOADING MODEL FROM GPU "+"*"*40)
            self.ckpt = torch.load(ckpt_path)
            self.net.load_state_dict(self.ckpt['model_state'])
            self.net.cuda()
class Trainer(object):
    """Fixed-configuration trainer for a single-channel ``UNet``.

    Hyper-parameters (epochs, batch size, patch size, augmentation flag) are
    set in ``__init__``; ``run`` trains with focal loss and saves the weights.
    """

    def __init__(self):
        torch.set_num_threads(4)
        self.n_epochs = 10
        self.batch_size = 1
        self.patch_size = 384
        self.is_augment = False
        self.cuda = torch.cuda.is_available()
        self.__build_model()

    def __build_model(self):
        """Create the UNet and move it to the GPU when CUDA is available."""
        network = UNet(1, 1, base=16)
        self.model = network.cuda() if self.cuda else network

    def __reshapetensor(self, tensor, itype='image'):
        """Merge the first two dimensions of ``tensor`` into one.

        Args:
            tensor: 5-D tensor when ``itype`` is ``'image'``, otherwise 4-D.
            itype: ``'image'`` selects the 5-D layout; any other value the
                4-D one.

        Returns:
            A view with the first two dimensions flattened together.
        """
        if itype == 'image':
            n0, n1, n2, n3, n4 = tensor.size()
            return tensor.view(n0 * n1, n2, n3, n4)
        n0, n1, n2, n3 = tensor.size()
        return tensor.view(n0 * n1, n2, n3)

    def __get_optimizer(self, **params):
        """Create the RAdam optimizer and its plateau scheduler.

        ``params`` is accepted but not used.
        """
        self.optimizer = RAdam(params=self.model.parameters(),
                               lr=1e-2,
                               weight_decay=1e-5)
        # self.scheduler = None
        self.scheduler = optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer,
            'max',
            factor=0.5,
            patience=10,
            verbose=True,
            min_lr=1e-5)

    def run(self, trainset, model_dir):
        """Train on ``trainset`` and store the weights under ``model_dir``.

        Args:
            trainset: iterable yielding ``(images, labels)`` pairs.
            model_dir: directory for ``model.pth``; created if missing.

        Returns:
            Path of the saved ``model.pth``.
        """
        banner = '=' * 100
        print(banner)
        print('Trainning model')
        print(banner)

        if not os.path.exists(model_dir):
            os.mkdir(model_dir)
        model_path = os.path.join(model_dir, 'model.pth')

        #loss_fn = DiceLoss()
        loss_fn = FocalLoss2d()
        #loss_fn = CombineLoss({'dice':0.5, 'focal':0.5})
        self.__get_optimizer()

        loss_history = []
        f1_history = []
        for epoch in range(self.n_epochs):
            for ith_batch, data in enumerate(trainset):
                if self.cuda:
                    images, labels = [d.cuda() for d in data]
                else:
                    images, labels = data
                images = self.__reshapetensor(images, itype='image')
                labels = self.__reshapetensor(labels, itype='label')

                preds = self.model(images)
                loss = loss_fn(preds, labels)

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                loss_history.append(loss.item())

                # Binarise the sigmoid output at 0.5 for the F1 computation.
                preds = torch.sigmoid(preds)
                preds[preds > 0.5] = 1
                preds[preds <= 0.5] = 0
                flat_preds = preds.cpu().detach().numpy().flatten()
                flat_labels = labels.cpu().detach().numpy().flatten()
                f1 = f1_score(flat_labels, flat_preds, average='binary')
                f1_history.append(f1)

                print('EPOCH : {}-----BATCH : {}-----LOSS : {}-----F1 : {}'.
                      format(epoch, ith_batch, loss.item(), f1))

            torch.save(self.model.state_dict(), model_path)

        return model_path
conversions = {29: 0, 76: 1, 150: 2, 179: 3, 226: 4, 255: 5} gray = cv2.cvtColor(label_im, cv2.COLOR_RGB2GRAY) for k in conversions.keys(): gray[gray == k] = conversions[k] # print(np.unique(gray)) return gray if __name__ == '__main__': net = UNet(model_dir_path=sys.argv[1], input_channels=3) test_model = sys.argv[2] image_path = sys.argv[3] label_path = sys.argv[4] patch = int(sys.argv[5]) net.load_state_dict(torch.load(test_model)) net.cuda(device=0) image_read = cv2.imread(image_path) label_read = cv2.imread(label_path) small_patch = patch // 4 full_i = image_read.shape[0] // small_patch full_j = image_read.shape[1] // small_patch # full_image = np.empty(shape=(small_patch*full_i+small_patch, full_j*small_patch+small_patch, 3)) # full_label = np.empty(shape=(small_patch*full_i+small_patch, full_j*small_patch+small_patch)) # full_pred = np.empty(shape=(small_patch*full_i+small_patch, full_j*small_patch+small_patch)) x, y = image_read.shape[0] // 2, image_read.shape[1] // 2 full_image = np.empty(shape=(x, y, 3)) full_label = np.empty(shape=(x, y)) full_pred = np.empty(shape=(x, y)) print(image_read.shape)
def train():
    """Train the model, evaluating on the dev set after every epoch.

    Reads its configuration from the module-level ``args``. With
    ``args.eval`` set it only evaluates the saved model and exits; with
    ``args.load_model`` set it resumes from the saved weights and score
    histories. The weights are saved to ``args.model_dir`` whenever the dev
    F1 improves, and the learning rate is multiplied by ``args.decay`` every
    ``args.decay_period`` epochs.
    """
    os.makedirs("train_model/", exist_ok=True)
    os.makedirs("result/", exist_ok=True)

    train_data, dev_data, word2id, id2word, char2id, opts = load_data(
        vars(args))
    model = UNet(opts)
    if args.use_cuda:
        model = model.cuda()

    # One answers path for every evaluation. The former
    # os.path.join("result/", name, ".answers") produced
    # "result/<name>/.answers", unlike the eval branch.
    answer_file = "result/" + args.model_dir.split("/")[-1] + ".answers"

    def evaluate(batches):
        # Single place for the Evaluate arguments shared by all call sites.
        return model.Evaluate(
            batches,
            os.path.join(args.prepro_dir, "dev_eval.json"),
            answer_file=answer_file,
            drop_file=os.path.join(args.prepro_dir, "drop.json"),
            dev=args.dev_file,
        )

    dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)

    if args.eval:
        print("load model...")
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        evaluate(dev_batches)
        # sys.exit is always available; the bare exit() helper comes from
        # the site module and may be missing.
        sys.exit()

    if args.load_model:
        print("load model...")
        model.load_state_dict(torch.load(args.model_dir))
        model.eval()
        _, F1 = evaluate(dev_batches)
        best_score = F1
        with open(args.model_dir + "_f1_scores.pkl", "rb") as f:
            f1_scores = pkl.load(f)
        with open(args.model_dir + "_em_scores.pkl", "rb") as f:
            em_scores = pkl.load(f)
    else:
        best_score = 0.0
        f1_scores = []
        em_scores = []

    # Must be a list: a filter() iterator would be exhausted by the optimizer
    # constructor, leaving clip_grad_norm_ below with nothing to clip.
    parameters = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.Adamax(parameters, lr=args.lrate)
    lrate = args.lrate

    for epoch in range(1, args.epochs + 1):
        train_batches = get_batches(train_data, args.batch_size)
        dev_batches = get_batches(dev_data, args.batch_size, evaluation=True)
        total_size = len(train_data) // args.batch_size

        model.train()
        for i, train_batch in enumerate(train_batches):
            loss = model(train_batch)
            model.zero_grad()
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(parameters, opts["grad_clipping"])
            optimizer.step()
            model.reset_parameters()
            if i % 100 == 0:
                print(
                    "Epoch = %d, step = %d / %d, loss = %.5f, lrate = %.5f best_score = %.3f"
                    % (epoch, i, total_size, model.train_loss.value, lrate,
                       best_score))
                sys.stdout.flush()

        model.eval()
        exact_match_score, F1 = evaluate(dev_batches)
        f1_scores.append(F1)
        em_scores.append(exact_match_score)
        with open(args.model_dir + "_f1_scores.pkl", "wb") as f:
            pkl.dump(f1_scores, f)
        with open(args.model_dir + "_em_scores.pkl", "wb") as f:
            pkl.dump(em_scores, f)

        if best_score < F1:
            best_score = F1
            print("saving %s ..." % args.model_dir)
            torch.save(model.state_dict(), args.model_dir)

        if epoch > 0 and epoch % args.decay_period == 0:
            lrate *= args.decay
            for param_group in optimizer.param_groups:
                param_group["lr"] = lrate
def train(args):
    """Train a UNet on the segmentation dataset and save its weights.

    Args:
        args: namespace providing ``dataset_path``, ``batch_size``,
            ``mask_json_path``, ``cuda``, ``lr``, ``epochs``, ``save_epoch``,
            ``checkpoint_path``, ``do_save_summary``, ``summary_image`` and
            ``model_state_dict``.

    Per-epoch losses are the sums of the batch losses. A checkpoint is
    written every ``save_epoch`` epochs and the final state dict is saved to
    ``args.model_state_dict``.
    """
    # dataset
    print('Reading dataset from {}...'.format(args.dataset_path))
    train_dataset = SSDataset(dataset_path=args.dataset_path, is_train=True)
    val_dataset = SSDataset(dataset_path=args.dataset_path, is_train=False)
    train_dataloader = DataLoader(dataset=train_dataset,
                                  batch_size=args.batch_size,
                                  shuffle=True)
    val_dataloader = DataLoader(dataset=val_dataset,
                                batch_size=args.batch_size,
                                shuffle=False)

    # mask
    with open(args.mask_json_path, 'w', encoding='utf-8') as mask:
        colors = SSDataset.all_colors
        mask.write(json.dumps(colors))
        print('Mask colors list has been saved in {}'.format(
            args.mask_json_path))

    # model
    net = UNet(in_channels=3, out_channels=5)
    if args.cuda:
        net = net.cuda()

    # setting
    lr = args.lr  # 1e-3
    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = loss_fn

    # run
    train_losses = []
    val_losses = []
    print('Start training...')
    for epoch_idx in range(args.epochs):
        # train
        net.train()
        train_loss = 0.0
        for xs, ys in train_dataloader:
            if args.cuda:
                xs = xs.cuda()
                ys = ys.cuda()
            ys_pred = net(xs)
            loss = criterion(ys_pred, ys)
            # Accumulate a plain float; summing the tensors would keep every
            # batch's autograd graph alive for the whole epoch.
            train_loss += loss.item()

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # val
        net.eval()
        val_loss = 0.0
        with torch.no_grad():  # no gradients needed for validation
            for xs, ys in val_dataloader:
                if args.cuda:
                    xs = xs.cuda()
                    ys = ys.cuda()
                ys_pred = net(xs)
                val_loss += criterion(ys_pred, ys).item()

        train_losses.append(train_loss)
        val_losses.append(val_loss)
        print('Epoch: {}, Train total loss: {}, Val total loss: {}'.format(
            epoch_idx + 1, train_loss, val_loss))

        # save
        if (epoch_idx + 1) % args.save_epoch == 0:
            checkpoint_path = os.path.join(
                args.checkpoint_path,
                'checkpoint_{}.pth'.format(epoch_idx + 1))
            torch.save(net.state_dict(), checkpoint_path)
            print('Saved Checkpoint at Epoch {} to {}'.format(
                epoch_idx + 1, checkpoint_path))

    # summary
    if args.do_save_summary:
        epoch_range = list(range(1, args.epochs + 1))
        plt.plot(epoch_range, train_losses, 'r', label='Train loss')
        # Plot the per-epoch history, not the last epoch's scalar.
        plt.plot(epoch_range, val_losses, 'g', label='Val loss')
        plt.legend()
        # savefig writes the current figure; imsave expects an image array.
        plt.savefig(args.summary_image)
        print('Summary images have been saved in {}'.format(
            args.summary_image))

    # save
    net.eval()
    torch.save(net.state_dict(), args.model_state_dict)
    print('Saved state_dict in {}'.format(args.model_state_dict))
from model import UNet, DNet import data_loader from data_loader import * ############################################################## # Initialise the generator and discriminator with the UNet and # DNet architectures respectively. generator = UNet(True) discriminator = DNet() ################################################################## # Utilize GPU for performing all the calculations performed in the # forward and backward passes. Thus allocate all the generator and # discriminator variables on the default GPU device. generator.cuda() discriminator.cuda() ################################################################### # Create ADAM optimizer for the generator as well the discriminator. # Create loss criterion for calculating the L1 and adversarial loss. d_optimizer = optim.Adam(discriminator.parameters(), betas=(0.5, 0.999), lr=0.0002) g_optimizer = optim.Adam(generator.parameters(), betas=(0.5, 0.999), lr=0.0002) d_criterion = nn.BCELoss() g_criterion_1 = nn.BCELoss() g_criterion_2 = nn.L1Loss() train_() def train_():
def run_inference(args):
    """Predict a forest map per district and year and save each as a PNG.

    Args:
        args: namespace providing ``model_topology``, ``bands``, ``classes``,
            ``model_path``, ``cuda``, ``device``, ``data_path``,
            ``rasterized_shapefiles_path``, ``bs`` and ``dest``.

    For each district/year the image is processed in 128-pixel model inputs,
    the arg-max class of every patch is written into a full-size map, and the
    map is rendered with forest pixels in the green band and non-forest
    pixels in the red band.
    """
    model = UNet(topology=args.model_topology,
                 input_channels=len(args.bands),
                 num_classes=len(args.classes))
    model.load_state_dict(torch.load(args.model_path, map_location='cpu'),
                          strict=False)
    print('Log: Loaded pretrained {}'.format(args.model_path))
    model.eval()
    if args.cuda:
        print('log: Using GPU')
        model.cuda(device=args.device)

    # all_districts = ["abbottabad", "battagram", "buner", "chitral", "hangu", "haripur", "karak", "kohat", "kohistan", "lower_dir", "malakand", "mansehra",
    #                  "nowshehra", "shangla", "swat", "tor_ghar", "upper_dir"]
    all_districts = ["abbottabad"]
    # years = [2014, 2016, 2017, 2018, 2019, 2020]
    years = [2016]

    # change this to do this for all the images in that directory
    for district in all_districts:
        for year in years:
            print("(LOG): On District: {} @ Year: {}".format(district, year))
            test_image_path = os.path.join(
                args.data_path,
                'landsat8_{}_region_{}.tif'.format(year, district))
            inference_loader, adjustment_mask = get_inference_loader(
                rasterized_shapefiles_path=args.rasterized_shapefiles_path,
                district=district,
                image_path=test_image_path,
                model_input_size=128,
                bands=args.bands,
                num_classes=len(args.classes),
                batch_size=args.bs,
                num_workers=4)

            # we need to fill our new generated test image
            # NOTE(review): np.empty leaves any pixel not covered by a patch
            # uninitialised -- confirm the loader tiles the whole image.
            generated_map = np.empty(
                shape=inference_loader.dataset.get_image_size())

            # Inference only: skip building autograd graphs for every batch.
            with torch.no_grad():
                for idx, data in enumerate(inference_loader):
                    coordinates = data['coordinates'].tolist()
                    test_x = data['input']
                    if args.cuda:
                        test_x = test_x.cuda(device=args.device)
                    out_x, softmaxed = model(test_x)
                    pred = torch.argmax(softmaxed, dim=1)
                    # Move the batch axis last so patch k is [:, :, k].
                    pred_numpy = pred.cpu().numpy().transpose(1, 2, 0)
                    if idx % 5 == 0:
                        print('LOG: on {} of {}'.format(
                            idx, len(inference_loader)))
                    for k in range(test_x.shape[0]):
                        x, x_, y, y_ = coordinates[k]
                        generated_map[x:x_, y:y_] = pred_numpy[:, :, k]

            # adjust the inferred map
            generated_map += 1  # to make forest pixels: 2, non-forest pixels: 1, null pixels: 0
            generated_map = np.multiply(generated_map, adjustment_mask)

            # save generated map as png image, not numpy array
            forest_map_rband = np.zeros_like(generated_map)
            forest_map_gband = np.zeros_like(generated_map)
            forest_map_bband = np.zeros_like(generated_map)
            forest_map_gband[generated_map == FOREST_LABEL] = 255
            forest_map_rband[generated_map == NON_FOREST_LABEL] = 255
            forest_map_for_visualization = np.dstack(
                [forest_map_rband, forest_map_gband,
                 forest_map_bband]).astype(np.uint8)
            save_this_map_path = os.path.join(
                args.dest, '{}_{}_inferred_map.png'.format(district, year))
            matimg.imsave(save_this_map_path, forest_map_for_visualization)
            print('Saved: {} @ {}'.format(save_this_map_path,
                                          forest_map_for_visualization.shape))
# initial_epoch = 150 if initial_epoch > 0: print('resuming by loading epoch %03d' % initial_epoch) u_model.load_state_dict( torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch))) # model = torch.load(os.path.join(save_dir, 'model_%03d.pth' % initial_epoch)) model.eval() u_model.train() criterion = nn.MSELoss() if cuda: model = model.cuda() u_model = u_model.cuda() optimizer = optim.Adam(model.parameters(), lr=args.lr) scheduler = MultiStepLR(optimizer, milestones=[30, 60, 90], gamma=0.2) # learning rates for epoch in range(initial_epoch, n_epoch): scheduler.step(epoch) # step to the learning rate in this epcoh xs = dg.datagenerator(data_dir=args.train_data) xs = xs.astype('float32') / 255.0 xs = torch.from_numpy(xs.transpose( (0, 3, 1, 2))) # tensor of the clean patches, NXCXHXW DDataset = DenoisingDataset(xs, sigma) batch_y, batch_x = DDataset[:238336]
if __name__ == '__main__':
    # Script entry point: build the network, train it, and save the weights.
    # A Ctrl-C during training saves the current weights to 'interrupt.pth'.
    args = get_args()
    # os.environ["CUDA_VISIBLE_DEVICES"] = '0'
    net = UNet(input_channels=3, nclasses=1)
    writer = SummaryWriter(log_dir='../../log/sn1', comment='unet')
    # net.cuda()
    # import pdb
    # from torchsummary import summary
    # summary(net, (3,1000,1000))
    # pdb.set_trace()

    if args.gpu:
        # Wrap for data parallelism only when several GPUs are visible.
        several_gpus = torch.cuda.device_count() > 1
        net = nn.DataParallel(net) if several_gpus else net
        net.cuda()

    train_options = dict(net=net,
                         epochs=args.epochs,
                         batch_size=args.batchsize,
                         lr=args.lr,
                         gpu=args.gpu,
                         writer=writer,
                         load=args.load)
    try:
        train_net(**train_options)
        torch.save(net.state_dict(), 'model_fin.pth')
    except KeyboardInterrupt:
        torch.save(net.state_dict(), 'interrupt.pth')
        print('saved interrupt')
def _assign_nucleus_classes(labeled, types, n_classes=4):
    """Return the majority class id of every connected component.

    Args:
        labeled: integer array of component labels (0 is background).
        types: integer array of the same shape holding a class id in
            ``[0, n_classes)`` for every pixel.
        n_classes: number of class ids, background included.

    Returns:
        1-D integer array ``classes`` where ``classes[k]`` is the class id
        that covers the most pixels of component ``k``.
    """
    n_labels = int(labeled.max()) + 1 if labeled.size else 1
    # Explicit half-integer edges give exactly one bin per label and per class
    # id. Passing integer bin counts instead makes numpy spread equal-width
    # bins over the observed min..max, which mis-maps class ids whenever one
    # of the classes does not occur in the image.
    overlap = np.histogram2d(
        labeled.ravel(), types.ravel(),
        bins=(np.arange(n_labels + 1) - 0.5,
              np.arange(n_classes + 1) - 0.5))[0]
    return np.argmax(overlap, axis=1)


def main():
    """Run the trained model on every image in the test folder.

    Settings are read from ``Params().test``. For each image the function
    computes class probability maps (optionally averaged over 8 flip/rotation
    variants when ``tta`` is set), derives per-type instance segmentations
    (tumor / lymphocyte / stroma), and then:

    * if a label directory is configured, computes detection and
      segmentation metrics per image and prints/saves their averages;
    * if ``save_flag`` is set, writes probability maps, the labelled
      segmentation (16-bit TIFF) and coloured visualisations to disk.
    """
    params = Params()
    img_dir = params.test['img_dir']
    label_dir = params.test['label_dir']
    save_dir = params.test['save_dir']
    os.makedirs(save_dir, exist_ok=True)
    model_path = params.test['model_path']
    save_flag = params.test['save_flag']
    tta = params.test['tta']

    params.save_params('{:s}/test_params.txt'.format(params.test['save_dir']),
                       test=True)

    # metrics are only computed when ground-truth labels are available
    eval_flag = bool(label_dir)
    if eval_flag:
        test_results = dict()
        # clas_acc, recall, precision, F1, dice, iou, haus
        tumor_result = utils.AverageMeter(7)
        lym_result = utils.AverageMeter(7)
        stroma_result = utils.AverageMeter(7)
        all_result = utils.AverageMeter(7)
        conf_matrix = np.zeros((3, 3))

    # data transforms
    test_transform = get_transforms(params.transform['test'])

    model_name = params.model['name']
    if model_name == 'ResUNet34':
        model = ResUNet34(params.model['out_c'],
                          fixed_feature=params.model['fix_params'])
    elif params.model['name'] == 'UNet':
        model = UNet(3, params.model['out_c'])
    else:
        raise NotImplementedError()
    # The checkpoint is loaded into the DataParallel wrapper; the wrapper is
    # dropped again (model.module) after loading.
    model = torch.nn.DataParallel(model)
    model = model.cuda()
    cudnn.benchmark = True

    # ----- load trained model ----- #
    print("=> loading trained model")
    best_checkpoint = torch.load(model_path)
    model.load_state_dict(best_checkpoint['state_dict'])
    print("=> loaded model at epoch {}".format(best_checkpoint['epoch']))
    model = model.module

    # switch to evaluate mode
    model.eval()
    counter = 0
    print("=> Test begins:")

    img_names = os.listdir(img_dir)

    if save_flag:
        os.makedirs(save_dir, exist_ok=True)
        strs = img_dir.split('/')
        prob_maps_folder = '{:s}/{:s}_prob_maps'.format(save_dir, strs[-1])
        seg_folder = '{:s}/{:s}_segmentation'.format(save_dir, strs[-1])
        os.makedirs(prob_maps_folder, exist_ok=True)
        os.makedirs(seg_folder, exist_ok=True)

    # img_names = ['193-adca-5']
    # total_time = 0.0
    for img_name in img_names:
        # load test image
        print('=> Processing image {:s}'.format(img_name))
        img_path = '{:s}/{:s}'.format(img_dir, img_name)
        img = Image.open(img_path)
        ori_h = img.size[1]
        ori_w = img.size[0]
        name = os.path.splitext(img_name)[0]
        if eval_flag:
            label_path = '{:s}/{:s}_label.png'.format(label_dir, name)
            gt = misc.imread(label_path)

        input_tensor = test_transform((img, ))[0].unsqueeze(0)

        print('\tComputing output probability maps...')
        prob_maps = get_probmaps(input_tensor, model, params)
        if tta:
            img_hf = img.transpose(Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_vf = img.transpose(Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_hvf = img_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_hf = test_transform(
                (img_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_vf = test_transform(
                (img_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_hvf = test_transform((img_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_hf = get_probmaps(input_hf, model, params)
            prob_maps_vf = get_probmaps(input_vf, model, params)
            prob_maps_hvf = get_probmaps(input_hvf, model, params)

            # undo the flips so every map is aligned with the original image
            prob_maps_hf = np.flip(prob_maps_hf, 2)
            prob_maps_vf = np.flip(prob_maps_vf, 1)
            prob_maps_hvf = np.flip(np.flip(prob_maps_hvf, 1), 2)

            # rotation 90 and flips
            img_r90 = img.rotate(90, expand=True)
            img_r90_hf = img_r90.transpose(
                Image.FLIP_LEFT_RIGHT)  # horizontal flip
            img_r90_vf = img_r90.transpose(
                Image.FLIP_TOP_BOTTOM)  # vertical flip
            img_r90_hvf = img_r90_hf.transpose(
                Image.FLIP_TOP_BOTTOM)  # horizontal and vertical flips

            input_r90 = test_transform((img_r90, ))[0].unsqueeze(0)
            input_r90_hf = test_transform(
                (img_r90_hf, ))[0].unsqueeze(0)  # horizontal flip input
            input_r90_vf = test_transform(
                (img_r90_vf, ))[0].unsqueeze(0)  # vertical flip input
            input_r90_hvf = test_transform((img_r90_hvf, ))[0].unsqueeze(
                0)  # horizontal and vertical flip input

            prob_maps_r90 = get_probmaps(input_r90, model, params)
            prob_maps_r90_hf = get_probmaps(input_r90_hf, model, params)
            prob_maps_r90_vf = get_probmaps(input_r90_vf, model, params)
            prob_maps_r90_hvf = get_probmaps(input_r90_hvf, model, params)

            # undo the flips first, then the rotation
            prob_maps_r90 = np.rot90(prob_maps_r90, k=3, axes=(1, 2))
            prob_maps_r90_hf = np.rot90(np.flip(prob_maps_r90_hf, 2),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_vf = np.rot90(np.flip(prob_maps_r90_vf, 1),
                                        k=3,
                                        axes=(1, 2))
            prob_maps_r90_hvf = np.rot90(np.flip(
                np.flip(prob_maps_r90_hvf, 1), 2),
                                         k=3,
                                         axes=(1, 2))

            # utils.show_figures((np.array(img), np.array(img_r90_hvf),
            #                     np.swapaxes(np.swapaxes(prob_maps_r90_hvf, 0, 1), 1, 2)))

            # average of the 8 aligned predictions
            prob_maps = (prob_maps + prob_maps_hf + prob_maps_vf +
                         prob_maps_hvf + prob_maps_r90 + prob_maps_r90_hf +
                         prob_maps_r90_vf + prob_maps_r90_hvf) / 8

        pred = np.argmax(prob_maps, axis=0)  # prediction
        pred_inside = pred.copy()
        pred_inside[pred == 4] = 0  # set contours to background
        pred_nuclei_inside_labeled = measure.label(pred_inside > 0)

        pred_tumor_inside = pred_inside == 1
        pred_lym_inside = pred_inside == 2
        pred_stroma_inside = pred_inside == 3
        pred_3types_inside = pred_tumor_inside + pred_lym_inside * 2 + pred_stroma_inside * 3

        # find the correct class for each segmented nucleus
        classes = _assign_nucleus_classes(pred_nuclei_inside_labeled,
                                          pred_3types_inside)
        tumor_nuclei_indices = np.nonzero(classes == 1)
        lym_nuclei_indices = np.nonzero(classes == 2)
        stroma_nuclei_indices = np.nonzero(classes == 3)

        # solve the problem of one nucleus assigned with different labels
        pred_tumor_inside = np.isin(pred_nuclei_inside_labeled,
                                    tumor_nuclei_indices)
        pred_lym_inside = np.isin(pred_nuclei_inside_labeled,
                                  lym_nuclei_indices)
        pred_stroma_inside = np.isin(pred_nuclei_inside_labeled,
                                     stroma_nuclei_indices)

        # remove small objects
        pred_tumor_inside = morph.remove_small_objects(
            pred_tumor_inside, params.post['min_area'])
        pred_lym_inside = morph.remove_small_objects(pred_lym_inside,
                                                     params.post['min_area'])
        pred_stroma_inside = morph.remove_small_objects(
            pred_stroma_inside, params.post['min_area'])

        # connected component labeling
        pred_tumor_inside_labeled = measure.label(pred_tumor_inside)
        pred_lym_inside_labeled = measure.label(pred_lym_inside)
        pred_stroma_inside_labeled = measure.label(pred_stroma_inside)
        # Merge the three label maps into one id space: tumor ids become
        # multiples of 3, lymphocyte ids 3k-2, stroma ids 3k-1, so the type
        # can be recovered from ``id % 3`` (as done for ``gt`` below).
        pred_all_inside_labeled = pred_tumor_inside_labeled * 3 \
            + (pred_lym_inside_labeled * 3 - 2) * (pred_lym_inside_labeled > 0) \
            + (pred_stroma_inside_labeled * 3 - 1) * (pred_stroma_inside_labeled > 0)

        # dilation
        pred_tumor_labeled = morph.dilation(pred_tumor_inside_labeled,
                                            selem=morph.selem.disk(
                                                params.post['radius']))
        pred_lym_labeled = morph.dilation(pred_lym_inside_labeled,
                                          selem=morph.selem.disk(
                                              params.post['radius']))
        pred_stroma_labeled = morph.dilation(pred_stroma_inside_labeled,
                                             selem=morph.selem.disk(
                                                 params.post['radius']))
        pred_all_labeled = morph.dilation(pred_all_inside_labeled,
                                          selem=morph.selem.disk(
                                              params.post['radius']))

        # utils.show_figures([pred, pred2, pred_labeled])

        if eval_flag:
            print('\tComputing metrics...')
            gt_tumor = (gt % 3 == 0) * gt
            gt_lym = (gt % 3 == 1) * gt
            gt_stroma = (gt % 3 == 2) * gt

            tumor_detect_metrics = utils.accuracy_detection_clas(
                pred_tumor_labeled, gt_tumor, clas_flag=False)
            lym_detect_metrics = utils.accuracy_detection_clas(
                pred_lym_labeled, gt_lym, clas_flag=False)
            stroma_detect_metrics = utils.accuracy_detection_clas(
                pred_stroma_labeled, gt_stroma, clas_flag=False)
            all_detect_metrics = utils.accuracy_detection_clas(
                pred_all_labeled, gt, clas_flag=True)

            tumor_seg_metrics = utils.accuracy_object_level(
                pred_tumor_labeled, gt_tumor, hausdorff_flag=False)
            lym_seg_metrics = utils.accuracy_object_level(
                pred_lym_labeled, gt_lym, hausdorff_flag=False)
            stroma_seg_metrics = utils.accuracy_object_level(
                pred_stroma_labeled, gt_stroma, hausdorff_flag=False)
            all_seg_metrics = utils.accuracy_object_level(
                pred_all_labeled, gt, hausdorff_flag=True)

            # the last detection entry is kept out of the metric rows; for
            # the "all" case it is accumulated into conf_matrix instead
            tumor_metrics = [*tumor_detect_metrics[:-1], *tumor_seg_metrics]
            lym_metrics = [*lym_detect_metrics[:-1], *lym_seg_metrics]
            stroma_metrics = [*stroma_detect_metrics[:-1], *stroma_seg_metrics]
            all_metrics = [*all_detect_metrics[:-1], *all_seg_metrics]
            conf_matrix += np.array(all_detect_metrics[-1])

            # save result for each image
            test_results[name] = {
                'tumor': tumor_metrics,
                'lym': lym_metrics,
                'stroma': stroma_metrics,
                'all': all_metrics
            }

            # update the average result
            tumor_result.update(tumor_metrics)
            lym_result.update(lym_metrics)
            stroma_result.update(stroma_metrics)
            all_result.update(all_metrics)

        # save image
        if save_flag:
            print('\tSaving image results...')
            misc.imsave('{:s}/{:s}_pred.png'.format(prob_maps_folder, name),
                        pred.astype(np.uint8) * 50)
            misc.imsave(
                '{:s}/{:s}_prob_tumor.png'.format(prob_maps_folder, name),
                prob_maps[1, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_lym.png'.format(prob_maps_folder, name),
                prob_maps[2, :, :])
            misc.imsave(
                '{:s}/{:s}_prob_stroma.png'.format(prob_maps_folder, name),
                prob_maps[3, :, :])
            # np.save('{:s}/{:s}_prob.npy'.format(prob_maps_folder, name), prob_maps)
            # np.save('{:s}/{:s}_seg.npy'.format(seg_folder, name), pred_all_labeled)
            final_pred = Image.fromarray(pred_all_labeled.astype(np.uint16))
            final_pred.save('{:s}/{:s}_seg.tiff'.format(seg_folder, name))

            # save colored objects
            pred_colored = np.zeros((ori_h, ori_w, 3))
            pred_colored_instance = np.zeros((ori_h, ori_w, 3))
            pred_colored[pred_tumor_labeled > 0] = np.array([255, 0, 0])
            pred_colored[pred_lym_labeled > 0] = np.array([0, 255, 0])
            pred_colored[pred_stroma_labeled > 0] = np.array([0, 0, 255])
            filename = '{:s}/{:s}_seg_colored_3types.png'.format(
                seg_folder, name)
            misc.imsave(filename, pred_colored)
            for k in range(1, pred_all_labeled.max() + 1):
                pred_colored_instance[pred_all_labeled == k, :] = np.array(
                    utils.get_random_color())
            filename = '{:s}/{:s}_seg_colored.png'.format(seg_folder, name)
            misc.imsave(filename, pred_colored_instance)

            # img_overlaid = utils.overlay_edges(label_img, pred_labeled2, img)
            # filename = '{:s}/{:s}_comparison.png'.format(seg_folder, name)
            # misc.imsave(filename, img_overlaid)

        counter += 1
        if counter % 10 == 0:
            print('\tProcessed {:d} images'.format(counter))

    # print('Time: {:4f}'.format(total_time/counter))

    print('=> Processed all {:d} images'.format(counter))
    if eval_flag:
        print(
            'Average: clas_acc\trecall\tprecision\tF1\tdice\tiou\thausdorff\n'
            'tumor: {t[0]:.4f}, {t[1]:.4f}, {t[2]:.4f}, {t[3]:.4f}, {t[4]:.4f}, {t[5]:.4f}, {t[6]:.4f}\n'
            'lym: {l[0]:.4f}, {l[1]:.4f}, {l[2]:.4f}, {l[3]:.4f}, {l[4]:.4f}, {l[5]:.4f}, {l[6]:.4f}\n'
            'stroma: {s[0]:.4f}, {s[1]:.4f}, {s[2]:.4f}, {s[3]:.4f}, {s[4]:.4f}, {s[5]:.4f}, {s[6]:.4f}\n'
            'all: {a[0]:.4f}, {a[1]:.4f}, {a[2]:.4f}, {a[3]:.4f}, {a[4]:.4f}, {a[5]:.4f}, {a[6]:.4f}'
            .format(t=tumor_result.avg,
                    l=lym_result.avg,
                    s=stroma_result.avg,
                    a=all_result.avg))

        header = [
            'clas_acc', 'recall', 'precision', 'F1', 'Dice', 'IoU', 'Hausdorff'
        ]
        save_results(header, tumor_result.avg, lym_result.avg,
                     stroma_result.avg, all_result.avg, test_results,
                     conf_matrix, '{:s}/test_result.txt'.format(save_dir))
stage3_boxmodel.load_state_dict( torch.load(os.path.join(root_dir, stage3_model_load_dir))) logger.info('Stage3_Model loaded from {}'.format(stage3_model_load_dir)) stage3_pointmodel = UNet(n_channels=1, n_classes=5) # stage3_model_load_dir = "saved_model\stage3_unet_refine_point_mask\Bestmodel_394.pth" stage3_pointmodel.load_state_dict( torch.load(os.path.join(root_dir, stage3_model_load_dir))) logger.info('Stage3_Model loaded from {}'.format(stage3_model_load_dir)) # stage3_model = SCSE_UNet(n_channels=1, n_classes=2) # # stage3_model_load_dir = "saved_model/stage3_scseunet_refine_label/Bestmodel_82.pth" # stage3_model.load_state_dict(torch.load(os.path.join(root_dir, stage3_model_load_dir))) # logger.info('Stage3_Model loaded from {}'.format(stage3_model_load_dir)) stage1_model_whole.cuda() stage1_model_segm.cuda() stage2_model_box.cuda() stage3_boxmodel.cuda() stage3_pointmodel.cuda() cudnn.benchmark = True # faster convolutions, but more memory # pred_a = pred(s1_modelw=stage1_model_whole, # s1_models=stage1_model_segm, # s2_model=stage2_model_box, # stage3_model=stage3_model, # dataLoader=train_loader, # output_dir=train_output_path, # ) # pred_a.forward() pred_b = pred(
def main():
    """Train the segmentation model and validate it after every epoch.

    Reads all settings from ``Params()``, builds the model (``ResUNet34`` or
    ``UNet``), the VGG16 feature network, optimizer, loss and the train/val
    data loaders, optionally resumes from a checkpoint, then runs the epoch
    loop. After each epoch a checkpoint is written via ``save_checkpoint``
    and the losses / IoUs are logged to the results logger and TensorBoard.

    Populates the module-level globals declared below.
    """
    global params, best_iou, num_iter, tb_writer, logger, logger_results
    global vgg_model, criterion_perceptual

    best_iou = 0
    params = Params()
    params.save_params('{:s}/params.txt'.format(params.paths['save_dir']))
    tb_writer = SummaryWriter('{:s}/tb_logs'.format(params.paths['save_dir']))

    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(
        map(str, params.train['gpu']))

    # set up logger
    logger, logger_results = setup_logging(params)

    # ----- create model ----- #
    model_name = params.model['name']
    if model_name == 'ResUNet34':
        network = ResUNet34(params.model['out_c'],
                            fixed_feature=params.model['fix_params'])
    elif params.model['name'] == 'UNet':
        network = UNet(3, params.model['out_c'])
    else:
        raise NotImplementedError()
    logger.info('Model: {:s}'.format(model_name))
    # if not params.train['checkpoint']:
    #     logger.info(model)
    model = nn.DataParallel(network).cuda()

    logger.info('=> Using VGG16 for perceptual loss...')
    vgg_model = nn.DataParallel(vgg16_feat()).cuda()

    cudnn.benchmark = True

    # ----- define optimizer ----- #
    optimizer = torch.optim.Adam(model.parameters(),
                                 params.train['lr'],
                                 betas=(0.9, 0.99),
                                 weight_decay=params.train['weight_decay'])

    # ----- get pixel weights and define criterion ----- #
    if params.train['weight_map']:
        logger.info('=> Using weight maps...')
        # per-pixel losses are kept unreduced when weight maps are in use
        criterion = torch.nn.NLLLoss(reduction='none').cuda()
    else:
        criterion = torch.nn.NLLLoss().cuda()

    if params.train['beta'] > 0:
        logger.info('=> Using perceptual loss...')
        criterion_perceptual = perceptual_loss()

    data_transforms = {
        split: get_transforms(params.transform[split])
        for split in ('train', 'val')
    }

    # ----- load data ----- #
    datasets = {}
    for split in ('train', 'val'):
        img_dir = '{:s}/{:s}'.format(params.paths['img_dir'], split)
        target_dir = '{:s}/{:s}'.format(params.paths['label_dir'], split)

        if not params.train['weight_map']:
            dir_list = [img_dir, target_dir]
            postfix = ['label_with_contours.png']
            num_channels = [3, 3]
        else:
            weight_map_dir = '{:s}/{:s}'.format(
                params.paths['weight_map_dir'], split)
            dir_list = [img_dir, weight_map_dir, target_dir]
            postfix = ['weight.png', 'label_with_contours.png']
            num_channels = [3, 1, 3]

        datasets[split] = DataFolder(dir_list, postfix, num_channels,
                                     data_transforms[split])

    train_loader = DataLoader(datasets['train'],
                              batch_size=params.train['batch_size'],
                              shuffle=True,
                              num_workers=params.train['workers'])
    val_loader = DataLoader(datasets['val'],
                            batch_size=params.train['val_batch_size'],
                            shuffle=False,
                            num_workers=params.train['workers'])

    # ----- optionally load from a checkpoint for validation or resuming training ----- #
    if params.train['checkpoint']:
        if not os.path.isfile(params.train['checkpoint']):
            logger.info("=> no checkpoint found at '{}'".format(
                params.train['checkpoint']))
        else:
            logger.info("=> loading checkpoint '{}'".format(
                params.train['checkpoint']))
            checkpoint = torch.load(params.train['checkpoint'])
            # resume: restore epoch counter, best score, weights, optimizer
            params.train['start_epoch'] = checkpoint['epoch']
            best_iou = checkpoint['best_iou']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            logger.info("=> loaded checkpoint '{}' (epoch {})".format(
                params.train['checkpoint'], checkpoint['epoch']))

    # ----- training and validation ----- #
    num_iter = params.train['num_epochs'] * len(train_loader)

    # print training parameters
    logger.info("=> Initial learning rate: {:g}".format(params.train['lr']))
    logger.info("=> Batch size: {:d}".format(params.train['batch_size']))
    # logger.info("=> Number of training iterations: {:d}".format(num_iter))
    logger.info("=> Training epochs: {:d}".format(params.train['num_epochs']))
    logger.info("=> beta: {:.1f}".format(params.train['beta']))

    for epoch in range(params.train['start_epoch'],
                       params.train['num_epochs']):
        # train for one epoch or len(train_loader) iterations
        logger.info('Epoch: [{:d}/{:d}]'.format(epoch + 1,
                                                params.train['num_epochs']))
        (train_loss, train_loss_ce, train_loss_var, train_iou_nuclei,
         train_iou) = train(train_loader, model, optimizer, criterion, epoch)

        # evaluate on validation set
        with torch.no_grad():
            (val_loss, val_loss_ce, val_loss_var, val_iou_nuclei,
             val_iou) = validate(val_loader, model, criterion)

        # check if it is the best accuracy
        combined_iou = (val_iou_nuclei + val_iou) / 2
        is_best = combined_iou > best_iou
        best_iou = max(combined_iou, best_iou)

        cp_flag = (epoch + 1) % params.train['checkpoint_freq'] == 0

        state = {
            'epoch': epoch + 1,
            'state_dict': model.state_dict(),
            'best_iou': best_iou,
            'optimizer': optimizer.state_dict(),
        }
        save_checkpoint(state, epoch, is_best, params.paths['save_dir'],
                        cp_flag)

        # save the training results to txt files
        logger_results.info(
            '{:d}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}\t{:.4f}'
            .format(epoch + 1, train_loss, train_loss_ce, train_loss_var,
                    train_iou_nuclei, train_iou, val_loss, val_iou_nuclei,
                    val_iou))

        # tensorboard logs
        loss_scalars = {
            'train_loss': train_loss,
            'train_loss_ce': train_loss_ce,
            'train_loss_var': train_loss_var,
            'val_loss': val_loss
        }
        tb_writer.add_scalars('epoch_losses', loss_scalars, epoch)

        accuracy_scalars = {
            'train_iou_nuclei': train_iou_nuclei,
            'train_iou': train_iou,
            'val_iou_nuclei': val_iou_nuclei,
            'val_iou': val_iou
        }
        tb_writer.add_scalars('epoch_accuracies', accuracy_scalars, epoch)

    tb_writer.close()