def main():
    """Run federated inference on the test set, report metrics and draw heatmaps.

    Builds ``DenseNet121``, restores weights from ``CKPT_PATH`` when that file
    exists, federates the ten-crop test dataset across ``workers`` and evaluates
    every batch on the worker that holds it. It then prints:

    * the average and per-class AUROC / specificity / sensitivity / F1;
    * the thresholded findings and per-class confidence for each X-ray.

    Finally it calls ``get_heat_map`` for every test image.
    """
    cudnn.benchmark = True
    model = DenseNet121(N_CLASSES).cuda()

    # File names of the test images, in list order. They are needed by the
    # heatmap step at the end; without this list the final
    # ``image_names[index]`` lookup raised NameError.
    image_names = []
    with open(TEST_IMAGE_LIST, 'r') as f:
        for line in f:
            items = line.split()
            image_names.append(items[0])

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        # Matches keys like "...denselayerN.norm.1.weight" so that the
        # "norm.1" part can be folded into "norm1".
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        # Drop the first dotted component of every key (presumably a
        # DataParallel "module." prefix — confirm against the checkpoint).
        for key in list(state_dict.keys()):
            split_key = key.split('.')
            new_key = '.'.join(split_key[1:])
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
        model.load_state_dict(state_dict)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops])),
        ]))
    test_loader = sy.FederatedDataLoader(
        federated_dataset=test_dataset.federate(tuple(workers)),
        batch_size=BATCH_SIZE,
        shuffle=False)

    # Accumulators for ground-truth labels and averaged model outputs.
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    model.eval()
    # Same object as ``model`` (not a copy). It is used for heatmaps only after
    # the loop, once every ``model.send`` has been matched by ``model.get()``.
    model_nosyft = model

    for inp, target in test_loader:
        location = inp.location
        target = target.get().cuda()
        gt = torch.cat((gt, target), 0)

        inp = inp.get()
        bs, n_crops, c, h, w = inp.size()
        # Fold the crops into the batch dimension and send the batch back to
        # the worker that owns it; the model follows it there.
        input_var = inp.view(-1, c, h, w).cuda().send(location)
        model.send(location)
        output = model(input_var)
        # Average the per-crop outputs of each image.
        output_mean = output.view(bs, n_crops, -1).mean(1)
        output_mean_data = output_mean.get()
        model.get()
        pred = torch.cat((pred, output_mean_data.data), 0)

    AUROCs, specificity, sensitivity, F1score = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROCs).mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
        print('The specificity of {} is {}'.format(CLASS_NAMES[i], specificity[i]))
        print('The sensitivity of {} is {}'.format(CLASS_NAMES[i], sensitivity[i]))
        print('The F1 score of {} is {}'.format(CLASS_NAMES[i], F1score[i]))

    # Binarize with a single comparison. The former two-step in-place
    # assignment (<T -> 0, then >=T -> 1) marked every class positive
    # whenever THRESHOLD <= 0.
    pred_np = (pred.cpu().numpy() >= THRESHOLD).astype(np.float32)
    indexes = [[index for index, val in enumerate(arr) if val == 1]
               for arr in pred_np]

    for index, disease_array in enumerate(indexes):
        if disease_array:
            print(
                f'XRAY :: {index + 1} :: {[CLASS_NAMES[i] for i in disease_array]}'
            )
        else:
            print(f'XRAY :: {index + 1} :: No Disease detected')
        # ``.item()`` prints plain numbers instead of CUDA tensor reprs.
        print("Confidence of the 14 classes are {}".format({
            CLASS_NAMES[j]: pred[index][j].item() * 100
            for j in range(len(pred[index]))
        }))
        # NOTE(review): assumes the federated loader yields images in the same
        # order as TEST_IMAGE_LIST (shuffle=False); confirm this holds when the
        # dataset is split across several workers.
        get_heat_map(model_nosyft, image_names[index], DATA_DIR)
def main():
    """Run federated inference on the chest X-ray test set and draw heatmaps.

    Builds DenseNet121 and restores weights from CKPT_PATH when that file
    exists. The ten-crop test dataset is split between the ``bob`` and
    ``alice`` workers, and each batch is evaluated on the worker that holds
    it. Finally the thresholded findings are printed and ``get_heat_map`` is
    called for each test image.
    """

    def _fold_dense_layer_keys(weights):
        # In place: "...denselayerN.norm.1.weight" -> "...denselayerN.norm1.weight".
        layer_key = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        for old_name in list(weights.keys()):
            hit = layer_key.match(old_name)
            if not hit:
                continue
            weights[hit.group(1) + hit.group(2)] = weights[old_name]
            del weights[old_name]

    def _strip_first_component(weights):
        # In place: drop the first dotted component of every key (presumably
        # a DataParallel "module." prefix — confirm against the checkpoint).
        for old_name in list(weights.keys()):
            weights['.'.join(old_name.split('.')[1:])] = weights[old_name]
            del weights[old_name]

    cudnn.benchmark = True
    model = DenseNet121(N_CLASSES).cuda()

    # File names of the test images, in list order, for the heatmap step.
    with open(TEST_IMAGE_LIST, 'r') as listing:
        image_names = [row.split()[0] for row in listing]

    if not os.path.isfile(CKPT_PATH):
        print("=> no checkpoint found")
    else:
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        weights = checkpoint['state_dict']
        _fold_dense_layer_keys(weights)
        _strip_first_component(weights)
        model.load_state_dict(weights)
        print("=> loaded checkpoint")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    def _crops_to_tensor(crops):
        return torch.stack([transforms.ToTensor()(crop) for crop in crops])

    def _normalize_crops(crops):
        return torch.stack([normalize(crop) for crop in crops])

    ten_crop_pipeline = transforms.Compose([
        transforms.Resize(256),
        transforms.TenCrop(224),
        transforms.Lambda(_crops_to_tensor),
        transforms.Lambda(_normalize_crops),
    ])
    test_dataset = ChestXrayDataSet(data_dir=DATA_DIR,
                                    image_list_file=TEST_IMAGE_LIST,
                                    transform=ten_crop_pipeline)
    test_loader = sy.FederatedDataLoader(
        federated_dataset=test_dataset.federate((bob, alice)),
        batch_size=BATCH_SIZE,
        shuffle=False,
    )

    gt = torch.FloatTensor().cuda()  # allocated, but never filled in this variant
    pred = torch.FloatTensor().cuda()

    model.eval()
    # Same object as ``model`` (not a copy); it is used for heatmaps only
    # after every ``model.send`` below has been matched by ``model.get()``.
    model_nosyft = model

    for batch, _labels in test_loader:
        owner = batch.location
        batch = batch.get()
        bs, n_crops, c, h, w = batch.size()
        # Fold the crops into the batch dimension and run on the owning worker.
        remote_input = batch.view(-1, c, h, w).cuda().send(owner)
        model.send(owner)
        remote_output = model(remote_input)
        # Average the per-crop outputs of each image, then fetch the result.
        per_image = remote_output.view(bs, n_crops, -1).mean(1).get()
        model.get()
        pred = torch.cat((pred, per_image.data), 0)

    binarized = pred.cpu().numpy()
    binarized[binarized < THRESHOLD] = 0
    binarized[binarized >= THRESHOLD] = 1

    findings = [[col for col, flag in enumerate(row) if flag == 1]
                for row in binarized]
    for index, disease_array in enumerate(findings):
        if disease_array:
            print(
                f'XRAY :: {index + 1} :: {[CLASS_NAMES[i] for i in disease_array]}'
            )
        else:
            print(f'XRAY :: {index + 1} :: No Disease detected')
        get_heat_map(model_nosyft, image_names[index], DATA_DIR)