def evaluate_segmentation(net_segmentation):
    """Evaluate a segmentation network on the validation split.

    Builds a loader over the validation images (no cropping, no mirroring,
    batch size 1). It runs the network on every sample and accumulates an
    ``nClasses x nClasses`` histogram with ``fast_hist``. The metrics are
    derived from that histogram with ``performMetrics``. The running mIoU is
    shown in a tqdm progress bar, and the final metrics are printed.

    Relies on module-level names defined elsewhere in this script:
    ``nClasses``, ``val_img_root``, ``val_gt_root``, ``val_image_list``,
    ``dataset``, ``out``, ``device``, ``segmentation_data_loader``,
    ``fast_hist``, ``performMetrics`` and ``tqdm``.

    Args:
        net_segmentation: the model to evaluate. It is switched to eval mode
            and is left in eval mode on return.

    Returns:
        Tuple ``(miou, p_acc, fwacc)`` as returned by ``performMetrics`` for
        the accumulated histogram.
    """
    net_segmentation.eval()
    val_seg_loader = torch.utils.data.DataLoader(
        segmentation_data_loader(
            img_root=val_img_root, gt_root=val_gt_root,
            image_list=val_image_list, suffix=dataset, out=out,
            crop=False, mirror=False),
        batch_size=1, num_workers=8, shuffle=False)

    hist = np.zeros((nClasses, nClasses))
    progbar = tqdm(total=len(val_seg_loader), desc='Eval')
    try:
        # Inference only: no_grad avoids building an autograd graph for every
        # validation image, which would otherwise waste memory and time.
        with torch.no_grad():
            for inputs_, targets in val_seg_loader:
                # Variable() wrappers are no-ops since PyTorch 0.4 (which is
                # also when .to(device) appeared), so plain tensors are used.
                inputs_ = inputs_.to(device)
                targets = targets.to(device)
                outputs = net_segmentation(inputs_)
                # Arg-max over dim 1 gives the predicted class per pixel
                # (assumes outputs is (N, C, H, W) -- TODO confirm).
                _, predicted = torch.max(outputs, 1)
                correct_label = targets.view(-1, targets.size(1), targets.size(2))
                hist += fast_hist(
                    correct_label.view(correct_label.size(0), -1).cpu().numpy(),
                    predicted.view(predicted.size(0), -1).cpu().numpy(),
                    nClasses)
                # Running mIoU, only for the progress-bar display.
                miou, _, _ = performMetrics(hist)
                progbar.set_description('Eval (mIoU=%.4f)' % (miou))
                progbar.update(1)
    finally:
        # Always release the progress bar, even if inference raises.
        progbar.close()

    miou, p_acc, fwacc = performMetrics(hist)
    print('\n mIoU: ', miou)
    print('\n Pixel accuracy: ', p_acc)
    print('\n Frequency Weighted Pixel accuracy: ', fwacc)
    return miou, p_acc, fwacc
def visualize_segmentation(net_segmentation, num_samples=4):
    """Plot input, ground truth and prediction for the first validation images.

    Draws a grid with one row per sample and three columns:

    - The input image, with ``mean_bgr`` added back and the channel order
      reversed. This presumably undoes the loader's mean subtraction and
      turns BGR into RGB -- TODO confirm against ``segmentation_data_loader``.
    - The ground-truth label map.
    - The predicted label map.

    Both label maps are colored with ``apply_color_map`` and ``c_map``.

    Uses module-level names defined elsewhere in this script:
    ``val_img_root``, ``val_gt_root``, ``val_image_list``, ``dataset``,
    ``out``, ``device``, ``mean_bgr``, ``c_map``, ``apply_color_map``,
    ``segmentation_data_loader`` and ``plt``.

    Args:
        net_segmentation: the model to run. It is put in eval mode for the
            forward passes, and its previous train/eval mode is restored
            afterwards.
        num_samples: number of validation images (grid rows) to show. The
            default of 4 reproduces the original fixed-size plot.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    if num_samples < 1:
        raise ValueError('num_samples must be positive, got %r' % (num_samples,))

    val_seg_loader = torch.utils.data.DataLoader(
        segmentation_data_loader(
            img_root=val_img_root, gt_root=val_gt_root,
            image_list=val_image_list, suffix=dataset, out=out,
            crop=False, mirror=False),
        batch_size=1, num_workers=8, shuffle=False)

    # squeeze=False keeps axs 2-D, so axs[row, col] also works for one row.
    # The height scales with the row count; 4 rows gives the original (9, 9).
    fig, axs = plt.subplots(nrows=num_samples, ncols=3,
                            figsize=(9, 2.25 * num_samples), squeeze=False)

    # Predictions must come from eval mode (BN/dropout). Remember the
    # caller's mode so that training code calling this is not disturbed.
    was_training = net_segmentation.training
    net_segmentation.eval()
    try:
        with torch.no_grad():
            # zip checks range first, so no extra batch is pulled from the
            # loader once num_samples rows have been drawn.
            for row, (inputs_, targets) in zip(range(num_samples), val_seg_loader):
                inputs_ = inputs_.to(device)
                targets = targets.to(device)
                outputs = net_segmentation(inputs_)
                # Arg-max over dim 1 -> predicted class per pixel
                # (assumes outputs is (N, C, H, W) -- TODO confirm).
                _, predicted = torch.max(outputs, 1)

                image = inputs_[0].cpu().numpy().transpose(1, 2, 0)
                image = image + mean_bgr[np.newaxis, np.newaxis, :]
                # Clip before casting: out-of-range float->uint8 casts are
                # not well-defined in NumPy.
                image = np.clip(image, 0, 255).astype(np.uint8)[:, :, ::-1]

                axs[row, 0].imshow(image)
                axs[row, 1].imshow(apply_color_map(targets[0].cpu().data, c_map))
                axs[row, 2].imshow(apply_color_map(predicted[0].cpu().data, c_map))
    finally:
        net_segmentation.train(was_training)

    axs[0, 0].set_title('input', fontsize=18)
    axs[0, 1].set_title('GT', fontsize=18)
    axs[0, 2].set_title('Pred', fontsize=18)
    fig.tight_layout()
    plt.show()
from loss import soft_iou from metric import fast_hist, performMetrics from utils.dataloaders import segmentation_data_loader train_seg_loss = [] val_seg_loss = [] train_seg_iou = [] val_seg_iou = [] ITER_SIZE = 2 ### accumulate gradients over ITER_SIZE iterations best_iou = 0. train_seg_loader = torch.utils.data.DataLoader(segmentation_data_loader( img_root=train_img_root, gt_root=train_gt_root, image_list=train_image_list_path + supervised_split + '.txt', suffix=dataset, out=out, crop=True, crop_shape=[256, 256], mirror=True), batch_size=32, num_workers=8, shuffle=True) val_seg_loader = torch.utils.data.DataLoader(segmentation_data_loader( img_root=val_img_root, gt_root=val_gt_root, image_list=val_image_list, suffix=dataset, out=out, crop=False,