import csv
import glob
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

# The model, dataset, metric, and I/O helpers (dataset.MRIDataset, U_Net,
# R2U_Net, runningScore, save_net, save_output, reconstruct_image, and the
# get_* metric functions) are assumed to come from this repository's own
# modules.


def train(cycle_num, dirs, path_to_net, plotter, batch_size=12,
          test_split=0.3, random_state=666, epochs=100, learning_rate=0.0001,
          momentum=0.9, num_folds=5, num_slices=155, n_classes=4):
    """Trains the network for one fold of the cross validation.

    Args:
        cycle_num (int): number of the cycle in the n-fold (num_folds)
            cross validation
        dirs (string): path to the dataset subject directories
        path_to_net (string): path to the directory where the network is saved
        plotter (callable): visdom plotter
        batch_size (int): batch size; default 12
        test_split (float): fraction of the data held out as the test split;
            default 0.3
        random_state (int): seed for the k-fold cross validation; default 666
        epochs (int): number of epochs; default 100
        learning_rate (float): learning rate; default 0.0001
        momentum (float): momentum; default 0.9
        num_folds (int): number of folds in the cross validation; default 5
        num_slices (int): number of slices per volume; default 155
        n_classes (int): number of classes (regions); default 4
    """
    print('Setup started', flush=True)
    # Create the data indices: arange over the list of subject directories.
    indices = np.arange(len(glob.glob(dirs + '*')))
    test_indices, trainset_indices = get_test_indices(indices, test_split)
    # k-fold index generator.
    for cv_num, (train_indices, val_indices) in enumerate(
            get_train_cv_indices(trainset_indices, num_folds, random_state)):
        # The 5-fold CV is split into 5 jobs; run only the fold that
        # matches this cycle.
        if cv_num != int(cycle_num):
            continue
        net = U_Net()
        device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        num_GPU = torch.cuda.device_count()
        if num_GPU > 1:
            print('Let us use {} GPUs!'.format(num_GPU), flush=True)
            net = nn.DataParallel(net)
        net.to(device)
        criterion = nn.CrossEntropyLoss()
        # Alternate the optimizer between folds: SGD with momentum for even
        # cycle numbers, Adam for odd ones.
        if int(cycle_num) % 2 == 0:
            optimizer = optim.SGD(net.parameters(), lr=learning_rate,
                                  momentum=momentum)
        else:
            optimizer = optim.Adam(net.parameters(), lr=learning_rate)
        scheduler = ReduceLROnPlateau(optimizer, threshold=1e-6, patience=0)
        print('cv cycle number: ', cycle_num, flush=True)
        start = time.time()
        print('Start train and val loading', flush=True)
        MRIDataset_train = dataset.MRIDataset(dirs, train_indices)
        MRIDataset_val = dataset.MRIDataset(dirs, val_indices)
        datalengths = {'train': len(MRIDataset_train),
                       'val': len(MRIDataset_val)}
        dataloaders = {'train': get_dataloader(MRIDataset_train, batch_size,
                                               num_GPU),
                       'val': get_dataloader(MRIDataset_val, batch_size,
                                             num_GPU)}
        print('Train and val loading took: ', time.time() - start, flush=True)
        # Set up the metrics; loss and accuracy histories are kept separately
        # for train and val.
        running_metrics_val = runningScore(n_classes)
        running_metrics_train = runningScore(n_classes)
        val_loss_meter = averageMeter()
        train_loss_meter = averageMeter()
        itr = 0
        iou_best = 0.
        for epoch in tqdm(range(epochs), desc='Epochs'):
            print('Epoch: ', epoch + 1, flush=True)
            phase = 'train'
            print('Phase: ', phase, flush=True)
            start = time.time()
            # Set the model to training mode.
            net.train()
            # Iterate over the data.
            for i, data in tqdm(enumerate(dataloaders[phase]),
                                desc='Data Iteration ' + phase):
                if (i + 1) % 100 == 0:
                    print('Iteration [{}/{}]'.format(
                        i + 1, int(datalengths[phase] / batch_size)),
                        flush=True)
                # Get the inputs.
                inputs = data['mri_data'].to(device)
                GT = data['seg'].to(device)
                subject_slice_path = data['subject_slice_path']
                # Clear all accumulated gradients.
                optimizer.zero_grad()
                # Predict the segmentation from the training inputs.
                SR = net(inputs)
                # Compute the loss between the predictions and the actual
                # segmentation.
                loss = criterion(SR, GT)
                # Backpropagate the loss.
                loss.backward()
                # Adjust the parameters according to the computed gradients
                # (weight update).
                optimizer.step()
                # Track and plot metrics and loss, and save the network.
                predictions = SR.data.max(1)[1].cpu().numpy()
                GT_cpu = GT.data.cpu().numpy()
                running_metrics_train.update(GT_cpu, predictions)
                train_loss_meter.update(loss.item(), n=1)
                if (i + 1) % 100 == 0:
                    itr += 1
                    score, class_iou = running_metrics_train.get_scores()
                    for k, v in score.items():
                        plotter.plot(k, 'itr', phase, k, itr, v)
                    for k, v in class_iou.items():
                        print('Class {} IoU: {}'.format(k, v), flush=True)
                        plotter.plot(str(k) + ' Class IoU', 'itr', phase,
                                     str(k) + ' Class IoU', itr, v)
                    print('Loss Train', train_loss_meter.avg, flush=True)
                    plotter.plot('Loss', 'itr', phase, 'Loss Train', itr,
                                 train_loss_meter.avg)
            print('Phase {} took {} s for the whole {} set!'.format(
                phase, time.time() - start, phase), flush=True)
            # Validation phase.
            phase = 'val'
            print('Phase: ', phase, flush=True)
            # Set the model to evaluation mode.
            net.eval()
            start = time.time()
            with torch.no_grad():
                # Iterate over the data.
                for i, data in tqdm(enumerate(dataloaders[phase]),
                                    desc='Data Iteration ' + phase):
                    if (i + 1) % 100 == 0:
                        print('Iteration [{}/{}]'.format(
                            i + 1, int(datalengths[phase] / batch_size)),
                            flush=True)
                    # Get the inputs.
                    inputs = data['mri_data'].to(device)
                    GT = data['seg'].to(device)
                    subject_slice_path = data['subject_slice_path']
                    # Predict the segmentation from the validation inputs.
                    SR = net(inputs)
                    # Compute the loss between the predictions and the actual
                    # segmentation.
                    loss = criterion(SR, GT)
                    # Track and plot metrics and loss.
                    predictions = SR.data.max(1)[1].cpu().numpy()
                    GT_cpu = GT.data.cpu().numpy()
                    running_metrics_val.update(GT_cpu, predictions)
                    val_loss_meter.update(loss.item(), n=1)
                    if (i + 1) % 100 == 0:
                        itr += 1
                        score, class_iou = running_metrics_val.get_scores()
                        for k, v in score.items():
                            plotter.plot(k, 'itr', phase, k, itr, v)
                        for k, v in class_iou.items():
                            print('Class {} IoU: {}'.format(k, v), flush=True)
                            plotter.plot(str(k) + ' Class IoU', 'itr', phase,
                                         str(k) + ' Class IoU', itr, v)
                        print('Loss Val', val_loss_meter.avg, flush=True)
                        plotter.plot('Loss', 'itr', phase, 'Loss Val', itr,
                                     val_loss_meter.avg)
            # Every 10 epochs, checkpoint the network if the mean IoU improved.
            if (epoch + 1) % 10 == 0:
                if score['Mean IoU'] > iou_best:
                    save_net(path_to_net, batch_size, epoch, cycle_num,
                             train_indices, val_indices, test_indices,
                             net, optimizer)
                    iou_best = score['Mean IoU']
                    save_output(epoch, path_to_net, subject_slice_path,
                                SR.data.cpu().numpy(), GT_cpu)
            print('Phase {} took {} s for the whole {} set!'.format(
                phase, time.time() - start, phase), flush=True)
            # Adjust the learning rate after every epoch.
            scheduler.step(val_loss_meter.avg)
        # Save the network after training.
        save_net(path_to_net, batch_size, epochs, cycle_num, train_indices,
                 val_indices, test_indices, net, optimizer, iter_num=None)
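
# ---------------------------------------------------------------------------
# The helpers used by train() above are not defined in this file. The
# sketches below are assumptions reconstructed from their call sites, not the
# repository's actual implementations: get_test_indices splits off a test
# set, get_train_cv_indices yields the k-fold (train, val) index pairs,
# get_dataloader wraps a Dataset, and averageMeter keeps a running average.

from sklearn.model_selection import KFold  # assumed dependency


def get_test_indices(indices, test_split):
    # Shuffle the subject indices and split off the first `test_split`
    # fraction as the held-out test set.
    shuffled = np.random.permutation(indices)
    split = int(len(shuffled) * test_split)
    return shuffled[:split], shuffled[split:]


def get_train_cv_indices(indices, num_folds, random_state):
    # Yield (train, val) index pairs for the k-fold cross validation.
    kf = KFold(n_splits=num_folds, shuffle=True, random_state=random_state)
    for train_idx, val_idx in kf.split(indices):
        yield indices[train_idx], indices[val_idx]


def get_dataloader(mri_dataset, batch_size, num_GPU):
    # Wrap a Dataset in a shuffling DataLoader; scale the worker count with
    # the number of GPUs.
    return torch.utils.data.DataLoader(mri_dataset, batch_size=batch_size,
                                       shuffle=True,
                                       num_workers=max(1, num_GPU),
                                       pin_memory=True)


class averageMeter(object):
    # Running average of a scalar such as the loss; matches the interface
    # used above (.update(value, n) and .avg).
    def __init__(self):
        self.reset()

    def reset(self):
        self.val, self.sum, self.count, self.avg = 0., 0., 0, 0.

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
# ---------------------------------------------------------------------------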
class Solver(object):
    def __init__(self, config, train_loader, valid_loader, test_loader):
        # Data loaders
        self.train_loader = train_loader
        self.valid_loader = valid_loader
        self.test_loader = test_loader

        # Models
        self.unet = None
        self.optimizer = None
        self.img_ch = config['img_ch']
        self.output_ch = config['output_ch']
        self.criterion = torch.nn.BCELoss()  # binary cross-entropy loss

        # Hyper-parameters
        self.lr = config['lr']
        self.beta1 = config['beta1']  # first-moment decay rate in Adam
        self.beta2 = config['beta2']  # second-moment decay rate in Adam

        # Training settings
        self.num_epochs = config['num_epochs']
        # Note: the config key keeps its historical 'num_epoches_decay'
        # spelling.
        self.num_epochs_decay = config['num_epoches_decay']
        self.batch_size = config['batch_size']

        # Paths
        self.model_path = config['model_path']
        self.result_path = config['result_path']

        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu')
        self.model_type = config['model_type']
        self.t = config['t']
        self.unet_path = os.path.join(
            self.model_path,
            '%s-%d-%.4f-%d.pkl' % (self.model_type, self.num_epochs,
                                   self.lr, self.num_epochs_decay))
        self.best_epoch = 0
        # Per-epoch accuracy histories.
        self.train_accuracy = []
        self.validation_accuracy = []
        self.build_model()

    def build_model(self):
        """Build the network and its optimizer."""
        if self.model_type == 'U_Net':
            self.unet = U_Net(img_ch=1, output_ch=1)
        elif self.model_type == 'R2U_Net':
            self.unet = R2U_Net(img_ch=1, output_ch=1, t=self.t)
        # init_weights(self.unet, 'normal')
        self.optimizer = optim.Adam(list(self.unet.parameters()), self.lr,
                                    (self.beta1, self.beta2))
        self.unet.to(self.device)

    def train(self):
        """Trains the network, or loads it if a checkpoint already exists."""
        # Print the network information.
        num_params = 0
        for p in self.unet.parameters():
            num_params += p.numel()  # accumulate the number of model parameters
        print('The number of parameters: {}'.format(num_params))

        # ====================================== Training ===========================================#
        if os.path.isfile(self.unet_path):
            # Load the pretrained network.
            self.unet.load_state_dict(torch.load(self.unet_path))
            print('%s is successfully loaded from %s' % (self.model_type,
                                                         self.unet_path))
        else:
            lr = self.lr
            best_unet_score = 0.0
            for epoch in range(self.num_epochs):
                self.unet.train(True)
                epoch_loss = 0
                acc = 0.  # Accuracy
                SE = 0.   # Sensitivity (Recall)
                SP = 0.   # Specificity
                PC = 0.   # Precision
                F1 = 0.   # F1 Score
                JS = 0.   # Jaccard Similarity
                DC = 0.   # Dice Coefficient
                length = 0
                for i, (images, GT) in enumerate(self.train_loader):
                    images, GT = images.to(self.device), GT.to(self.device)

                    # Forward pass.
                    SR = self.unet(images)
                    SR_probs = torch.sigmoid(SR)
                    SR_flat = SR_probs.view(SR_probs.size(0), -1)  # size(0) is batch_size
                    GT_flat = GT.view(GT.size(0), -1)
                    loss = self.criterion(SR_flat, GT_flat)
                    epoch_loss += loss.item()

                    # Backprop + optimize.
                    self.unet.zero_grad()
                    loss.backward()
                    self.optimizer.step()

                    acc += get_accuracy(SR, GT)
                    SE += get_sensitivity(SR, GT)
                    SP += get_specificity(SR, GT)
                    PC += get_precision(SR, GT)
                    F1 += get_F1(SR, GT)
                    JS += get_JS(SR, GT)
                    DC += get_DC(SR, GT)
                    length += 1

                acc /= length
                SE /= length
                SP /= length
                PC /= length
                F1 /= length
                JS /= length
                DC /= length

                # Print the log info.
                print('Epoch [%d/%d], Loss: %.4f, \n[Training] Acc: %.4f, '
                      'SE: %.4f, SP: %.4f, PC: %.4f, F1: %.4f, JS: %.4f, '
                      'DC: %.4f' % (epoch + 1, self.num_epochs, epoch_loss,
                                    acc, SE, SP, PC, F1, JS, DC))
                self.train_accuracy.append(acc)

                # Decay the learning rate.
                if (epoch + 1) > (self.num_epochs - self.num_epochs_decay):
                    lr -= (self.lr / float(self.num_epochs_decay))
                    for param_group in self.optimizer.param_groups:
                        param_group['lr'] = lr
                    print('Decay learning rate to lr: {}.'.format(lr))

                # ===================================== Validation ====================================#
                self.unet.train(False)
                self.unet.eval()
                acc = 0.  # Accuracy
                SE = 0.   # Sensitivity (Recall)
                SP = 0.   # Specificity
                PC = 0.   # Precision
                F1 = 0.   # F1 Score
                JS = 0.   # Jaccard Similarity
                DC = 0.   # Dice Coefficient
                length = 0
                with torch.no_grad():
                    for i, (images, GT) in enumerate(self.valid_loader):
                        images, GT = images.to(self.device), GT.to(self.device)
                        SR = torch.sigmoid(self.unet(images))
                        acc += get_accuracy(SR, GT)
                        SE += get_sensitivity(SR, GT)
                        SP += get_specificity(SR, GT)
                        PC += get_precision(SR, GT)
                        F1 += get_F1(SR, GT)
                        JS += get_JS(SR, GT)
                        DC += get_DC(SR, GT)
                        length += 1

                acc /= length
                SE /= length
                SP /= length
                PC /= length
                F1 /= length
                JS /= length
                DC /= length
                unet_score = JS + DC
                print('[Validation] Acc: %.4f, SE: %.4f, SP: %.4f, PC: %.4f, '
                      'F1: %.4f, JS: %.4f, DC: %.4f'
                      % (acc, SE, SP, PC, F1, JS, DC))
                self.validation_accuracy.append(acc)

                if unet_score > best_unet_score:
                    best_unet_score = unet_score
                    self.best_epoch = epoch
                    # Keep the best parameters for each layer.
                    best_unet = self.unet.state_dict()
                    print('Best %s model score : %.4f' % (self.model_type,
                                                          best_unet_score))
                    torch.save(best_unet, self.unet_path)

    def test(self):
        self.unet.load_state_dict(torch.load(self.unet_path))
        self.unet.eval()
        acc = 0.  # Accuracy
        SE = 0.   # Sensitivity (Recall)
        SP = 0.   # Specificity
        PC = 0.   # Precision
        F1 = 0.   # F1 Score
        JS = 0.   # Jaccard Similarity
        DC = 0.   # Dice Coefficient
        length = 0
        result = []
        with torch.no_grad():
            for i, (images, GT) in enumerate(self.test_loader):
                images = images.to(self.device)
                GT = GT.to(self.device)
                SR = torch.sigmoid(self.unet(images))
                acc += get_accuracy(SR, GT)
                SE += get_sensitivity(SR, GT)
                SP += get_specificity(SR, GT)
                PC += get_precision(SR, GT)
                F1 += get_F1(SR, GT)
                JS += get_JS(SR, GT)
                DC += get_DC(SR, GT)
                length += 1
                result.extend(SR.detach().cpu().numpy())

        acc /= length
        SE /= length
        SP /= length
        PC /= length
        F1 /= length
        JS /= length
        DC /= length
        unet_score = JS + DC
        reconstruct_image(self, np.array(result))

        # Append the test metrics to the CSV summary.
        with open(os.path.join(self.result_path, 'result.csv'), 'a',
                  encoding='utf-8', newline='') as f:
            wr = csv.writer(f)
            wr.writerow([self.model_type, acc, SE, SP, PC, F1, JS, DC,
                         self.lr, self.best_epoch, self.num_epochs,
                         self.num_epochs_decay])
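
# ---------------------------------------------------------------------------
# A hypothetical usage sketch. The config keys match the ones read in
# Solver.__init__ (including the 'num_epoches_decay' spelling); the three
# loaders are assumed to yield (image, mask) batches of shape (N, 1, H, W),
# and the values shown are placeholders, not the repository's defaults:
#
#     config = {
#         'img_ch': 1, 'output_ch': 1,
#         'lr': 2e-4, 'beta1': 0.5, 'beta2': 0.999,
#         'num_epochs': 100, 'num_epoches_decay': 70, 'batch_size': 16,
#         'model_path': './models', 'result_path': './results',
#         'model_type': 'U_Net', 't': 3,
#     }
#     solver = Solver(config, train_loader, valid_loader, test_loader)
#     solver.train()  # trains, or loads the checkpoint if it already exists
#     solver.test()   # reloads the best weights and appends metrics to result.csv
# ---------------------------------------------------------------------------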