# Evaluation script fragment: loads saved U-Net checkpoints one after another
# and prepares per-checkpoint metric accumulators for the test set.
# NOTE(review): the source had lost its line breaks and indentation; the loop
# nesting below is reconstructed (the accumulators are reset after each
# checkpoint load, so they are assumed to live inside the loop) -- confirm.
test_image_dir = '../test/images/'
test_label_dir = '../test/labels/'
checkpoints_dir = '../checkpoints/'

# save_path = 'test_results/'
# if not exists(save_path):
#     os.mkdir(save_path)

# Single-class (binary mask) U-Net on 3-channel input, moved to the GPU and
# switched to evaluation mode before any checkpoint is loaded.
net = UNet(n_channels=3, n_classes=1)
net.cuda()
net.eval()

# checkpoint = 1..30 maps to files CP1.pth, CP6.pth, ..., CP146.pth
# (the index in the file name is 5 * checkpoint - 4).
for checkpoint in range(1, 31):
    net.load_state_dict(
        torch.load(checkpoints_dir + 'CP' + str(5 * checkpoint - 4) + '.pth'))

    transform1 = transforms.Compose([ToTensor()])
    test_dataset = Dataset_unet(test_image_dir, test_label_dir,
                                transform=transform1)
    # NOTE(review): `batch_size` is not defined in the visible part of the
    # script -- presumably set earlier in the file; confirm.
    dataloader = DataLoader(test_dataset, batch_size=batch_size)
    dataset_sizes = len(test_dataset)
    # Number of full batches (floor division via int()); this is 0 when the
    # dataset holds fewer samples than batch_size.
    batch_num = int(dataset_sizes / batch_size)

    # Running sums of the per-batch metrics for the current checkpoint.
    Sensitivity = 0
    Specificity = 0
    Precision = 0
    F1 = 0
    F2 = 0
    ACC_overall = 0
    IoU_poly = 0
    IoU_bg = 0
def train_net(image_dir, label_dir, boundary_dir, checkpoint_dir, net,
              epochs=300, batch_size=4, lr=0.0001, save_cp=True, gpu=True):
    """Train ``net`` with a combined area / boundary / boundary-constraint loss.

    The network is called as ``net(image)`` and must return three logits
    tensors ``(pred_area, pred_bd, pred_bd_cons)``.  Each is passed through a
    sigmoid and the total loss is::

        BCEDice(area, label) + BCE(bd, boundary) + BCE(bd_cons, boundary)
            + 0.5 * BCE(bd_cons, bd.detach())

    Args:
        image_dir: Image directory handed to ``Dataset_unet``.
        label_dir: Label directory handed to ``Dataset_unet``.
        boundary_dir: Boundary-map directory handed to ``Dataset_unet``.
        checkpoint_dir: Prefix for checkpoint files.  It is concatenated
            directly with ``'CP<n>.pth'``, so it must end with a path
            separator.
        net: Model to optimise.  It is not moved to a device here.
        epochs: Number of passes over the dataset.
        batch_size: Batch size; also used as the DataLoader worker count.
        lr: Initial SGD learning rate (divided by 10 after epochs 50 and 100).
        save_cp: When true, save ``net.state_dict()`` every 5th epoch
            (files CP1.pth, CP6.pth, ...).
        gpu: When true and CUDA is available, batches are moved to the GPU.

    Raises:
        ValueError: If the dataset yields no full batch, which would
            otherwise surface as a division by zero after the first epoch.

    The mean epoch loss is logged through a ``writer`` object that is not
    defined inside this function.
    """
    print(''' Starting training: Epochs: {} Batch size: {} Learning rate: {} Checkpoints: {} Want CUDA: {} '''.format(
        epochs, batch_size, lr, str(save_cp), str(gpu)))

    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                          weight_decay=0.0005)
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[50, 100], gamma=0.1)
    criterion_bcedice = LossFunction_yq.BCEDiceLoss()
    criterion_bce = LossFunction_yq.BCELoss()

    transform1 = transforms.Compose([ToTensor()])
    dataset = Dataset_unet(image_dir, label_dir, boundary_dir,
                           transform=transform1)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=batch_size, drop_last=True)

    # With drop_last=True this equals len(dataset) // batch_size.
    batch_num = len(dataloader)
    if batch_num == 0:
        raise ValueError(
            'Dataset has {} samples, fewer than one full batch of {}.'.format(
                len(dataset), batch_size))

    use_cuda = gpu and torch.cuda.is_available()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()
        epoch_loss = 0

        for sample_batched in dataloader:
            optimizer.zero_grad()

            train_image = sample_batched['image']
            train_label = sample_batched['label']
            train_boundary = sample_batched['boundary']
            if use_cuda:
                train_image = train_image.cuda()
                train_label = train_label.cuda()
                train_boundary = train_boundary.cuda()

            pred_area, pred_bd, pred_bd_cons = net(train_image)
            pred_area_probs = torch.sigmoid(pred_area)
            pred_bd_probs = torch.sigmoid(pred_bd)
            pred_bd_cons_probs = torch.sigmoid(pred_bd_cons)

            # loss area
            loss_area = criterion_bcedice(pred_area_probs, train_label)
            # loss bd
            loss_bd = criterion_bce(pred_bd_probs, train_boundary)
            # loss bd constraint 1: constraint branch vs. ground truth
            loss_bd_cons1 = criterion_bce(pred_bd_cons_probs, train_boundary)
            # loss bd constraint 2: constraint branch vs. the boundary branch.
            # The target is detached so gradients only flow into the
            # constraint branch for this term.
            pred_bd_probs_cp = pred_bd_probs.clone().detach()
            loss_bd_cons2 = criterion_bce(pred_bd_cons_probs, pred_bd_probs_cp)

            # total_loss
            loss = loss_area + loss_bd + loss_bd_cons1 + 0.5 * loss_bd_cons2
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print('Epoch finished ! Train Loss: {}'.format(epoch_loss / batch_num))
        writer.add_scalar('Train_Loss', epoch_loss / batch_num, epoch)

        if save_cp and epoch % 5 == 0:
            torch.save(net.state_dict(),
                       checkpoint_dir + 'CP{}.pth'.format(epoch + 1))

        # Step the scheduler only after this epoch's optimizer updates;
        # stepping it first skips the initial value of the LR schedule.
        lr_scheduler.step()
def _binary_segmentation_metrics(predict_probs, validate_label):
    """Return confusion-matrix metrics for one batch as a dict of floats.

    Both inputs are binarised at 0.5.  Counts are summed over every element
    of the batch.  When there are no true positives, TP is replaced by 1 so
    the ratios below stay defined (this mirrors the original behaviour).
    """
    pred_fg = (predict_probs >= 0.5).float()
    label_fg = (validate_label >= 0.5).float()
    pred_bg = (pred_fg == 0).float()
    label_bg = (label_fg == 0).float()

    # calculate TP, FP, TN, FN
    TP = pred_fg.mul(label_fg).sum()
    FP = pred_fg.mul(label_bg).sum()
    TN = pred_bg.mul(label_bg).sum()
    FN = pred_bg.mul(label_fg).sum()

    if TP.item() == 0:
        # Stay on TP's own device instead of forcing CUDA, so CPU runs work.
        TP = torch.ones_like(TP)

    # Sensitivity, hit rate, recall, or true positive rate
    sensitivity = TP / (TP + FN)
    # Specificity or true negative rate
    specificity = TN / (TN + FP)
    # Precision or positive predictive value
    precision = TP / (TP + FP)
    # F1 score = Dice
    f1 = 2 * precision * sensitivity / (precision + sensitivity)
    # F2 score
    f2 = 5 * precision * sensitivity / (4 * precision + sensitivity)
    # Overall accuracy
    acc_overall = (TP + TN) / (TP + FP + FN + TN)
    # IoU for poly
    iou_poly = TP / (TP + FP + FN)
    # IoU for background
    iou_bg = TN / (TN + FP + FN)
    # mean IoU
    iou_mean = (iou_poly + iou_bg) / 2.0

    return {
        'sensitivity': sensitivity.item(),
        'specificity': specificity.item(),
        'precision': precision.item(),
        'F1': f1.item(),
        'F2': f2.item(),
        'ACC_overall': acc_overall.item(),
        'IoU_poly': iou_poly.item(),
        'IoU_bg': iou_bg.item(),
        'IoU_mean': iou_mean.item(),
    }


def predict(validate_image_dir, validate_label_dir, validate_boundary_dir,
            checkpoints_dir, net, batch_size=1, gpu=True, file_num=41):
    """Score saved checkpoints on the validation set and log the metrics.

    For ``epoch`` in ``1 .. file_num - 1`` the checkpoint
    ``CP<5 * epoch - 4>.pth`` is loaded into ``net``, every validation batch
    is evaluated, and the batch-averaged metrics are logged under
    ``Validate/<name>`` through a ``writer`` object that is not defined
    inside this function.

    Args:
        validate_image_dir: Image directory handed to ``Dataset_unet``.
        validate_label_dir: Label directory handed to ``Dataset_unet``.
        validate_boundary_dir: Boundary directory handed to ``Dataset_unet``.
        checkpoints_dir: Prefix for checkpoint files.  It is concatenated
            directly with the file name, so it must end with a path separator.
        net: Model to evaluate.  Its first output (or its only output, if it
            returns a single tensor) is treated as the segmentation logits.
            The model is left in eval mode.
        batch_size: Validation batch size.
        gpu: When true and CUDA is available, batches are moved to the GPU.
        file_num: One past the last checkpoint index to evaluate.

    Raises:
        ValueError: If the validation dataset yields no batches.
    """
    transform1 = transforms.Compose([ToTensor()])
    validate_dataset = Dataset_unet(validate_image_dir, validate_label_dir,
                                    validate_boundary_dir,
                                    transform=transform1)
    dataloader = DataLoader(validate_dataset, batch_size=batch_size)

    # Average over the batches actually produced; the last batch may be
    # partial because drop_last is not set.
    batch_num = len(dataloader)
    if batch_num == 0:
        raise ValueError('Validation dataset is empty.')

    use_cuda = gpu and torch.cuda.is_available()
    metric_names = ('sensitivity', 'specificity', 'precision', 'F1', 'F2',
                    'ACC_overall', 'IoU_poly', 'IoU_bg', 'IoU_mean')

    for epoch in range(1, file_num):
        net.load_state_dict(
            torch.load(checkpoints_dir + 'CP' + str(5 * epoch - 4) + '.pth'))
        # Evaluation must not use training-mode layers or build autograd graphs.
        net.eval()

        totals = dict.fromkeys(metric_names, 0.0)
        with torch.no_grad():
            for sample_batched in dataloader:
                validate_image = sample_batched['image']
                validate_label = sample_batched['label']
                if use_cuda:
                    validate_image = validate_image.cuda()
                    validate_label = validate_label.cuda()

                # Accept any number of network outputs; only the first one
                # is scored here.
                outputs = net(validate_image)
                if isinstance(outputs, (tuple, list)):
                    predict_label = outputs[0]
                else:
                    predict_label = outputs
                predict_probs = torch.sigmoid(predict_label)

                batch_metrics = _binary_segmentation_metrics(predict_probs,
                                                             validate_label)
                for name in metric_names:
                    totals[name] += batch_metrics[name]

        for name in metric_names:
            writer.add_scalar('Validate/' + name, totals[name] / batch_num,
                              epoch)
def train_net(image_dir, label_dir, checkpoint_dir, net, epochs=150,
              batch_size=4, lr=0.01, save_cp=True, gpu=True):
    """Train ``net`` for binary segmentation with a BCE + Dice loss.

    The network is called as ``net(image)`` and its output is passed through
    a sigmoid before the loss is computed.

    Args:
        image_dir: Image directory handed to ``Dataset_unet``.
        label_dir: Label directory handed to ``Dataset_unet``.
        checkpoint_dir: Prefix for checkpoint files.  It is concatenated
            directly with ``'CP<n>.pth'``, so it must end with a path
            separator.
        net: Model to optimise.  It is not moved to a device here.
        epochs: Number of passes over the dataset.
        batch_size: Batch size; also used as the DataLoader worker count.
        lr: Initial SGD learning rate (divided by 10 after epochs 20, 100
            and 125).
        save_cp: When true, save ``net.state_dict()`` every 5th epoch
            (files CP1.pth, CP6.pth, ...).
        gpu: When true and CUDA is available, batches are moved to the GPU.

    Raises:
        ValueError: If the dataset yields no full batch, which would
            otherwise surface as a division by zero after the first epoch.

    The mean epoch loss is logged through a ``writer`` object that is not
    defined inside this function.
    """
    print(''' Starting training: Epochs: {} Batch size: {} Learning rate: {} Checkpoints: {} Want CUDA: {} '''.format(
        epochs, batch_size, lr, str(save_cp), str(gpu)))

    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9,
                          weight_decay=0.0005)
    # lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)  # decay LR: e.g. gamma = 0.1
    lr_scheduler = optim.lr_scheduler.MultiStepLR(
        optimizer, milestones=[20, 100, 125], gamma=0.1)
    criterion = LossFunction_yq.BCEDiceLoss()

    transform1 = transforms.Compose([ToTensor()])
    dataset = Dataset_unet(image_dir, label_dir, transform=transform1)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                            num_workers=batch_size, drop_last=True)

    # With drop_last=True this equals len(dataset) // batch_size.
    batch_num = len(dataloader)
    if batch_num == 0:
        raise ValueError(
            'Dataset has {} samples, fewer than one full batch of {}.'.format(
                len(dataset), batch_size))

    use_cuda = gpu and torch.cuda.is_available()

    for epoch in range(epochs):
        print('Starting epoch {}/{}.'.format(epoch + 1, epochs))
        net.train()
        epoch_loss = 0

        for sample_batched in dataloader:
            optimizer.zero_grad()

            train_image = sample_batched['image']
            train_label = sample_batched['label']
            if use_cuda:
                train_image = train_image.cuda()
                train_label = train_label.cuda()

            label_pred = net(train_image)
            label_probs = torch.sigmoid(label_pred)

            loss = criterion(label_probs, train_label)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        print('Epoch finished ! Train Loss: {}'.format(epoch_loss / batch_num))
        writer.add_scalar('Train_Loss', epoch_loss / batch_num, epoch)

        if save_cp and epoch % 5 == 0:
            torch.save(net.state_dict(),
                       checkpoint_dir + 'CP{}.pth'.format(epoch + 1))

        # Step the scheduler only after this epoch's optimizer updates;
        # stepping it first skips the initial value of the LR schedule.
        lr_scheduler.step()