Example #1
def main():

    cudnn.benchmark = True

    model = DenseNet121(N_CLASSES).cuda()

    # build the list of image file names used for heatmap generation
    image_names = []
    with open(TEST_IMAGE_LIST, 'r') as f:
        for line in f:
            items = line.split()
            image_names.append(items[0])

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
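        # older torchvision DenseNet checkpoints name layers like
        # 'norm.1.weight'; remap them to the current 'norm1.weight'
        # style so load_state_dict accepts them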
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
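        # drop the leading wrapper prefix on every key (e.g. 'module.' left
        # by nn.DataParallel) so the keys match this bare model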
        for key in list(state_dict.keys()):
            split_key = key.split('.')
            new_key = '.'.join(split_key[1:])
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
        model.load_state_dict(state_dict)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
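            # TenCrop returns a tuple of 10 PIL crops per image, so each crop
            # is converted to a tensor and normalized individually, giving a
            # (10, C, H, W) tensor per sample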
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
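    # federate() shards the dataset across the remote workers; the loader
    # then yields pointer tensors that live on those workers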
    test_loader = sy.FederatedDataLoader(
        federated_dataset=test_dataset.federate(tuple(workers)),
        batch_size=BATCH_SIZE,
        shuffle=False)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model.eval()
    model_nosyft = model
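    # for each batch: send the model and inputs to the worker holding the
    # data, run inference remotely, then fetch the crop-averaged outputs back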
    for i, (inp, target) in enumerate(test_loader):
        location = inp.location
        target = target.get().cuda()
        gt = torch.cat((gt, target), 0)
        inp = inp.get()
        bs, n_crops, c, h, w = inp.size()
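        # fold the 10 crops into the batch dimension for the forward pass,
        # then average the per-crop outputs back into one prediction per image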
        input_var = inp.view(-1, c, h, w).cuda().send(location)
        model.send(location)
        output = model(input_var)
        output_mean = output.view(bs, n_crops, -1).mean(1)
        output_mean_data = output_mean.get()
        model.get()
        pred = torch.cat((pred, output_mean_data.data), 0)
    pred_np = pred.cpu().numpy()
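    # binarize the per-class scores against THRESHOLD to decide which
    # findings are reported for each image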
    pred_np[pred_np < THRESHOLD] = 0
    pred_np[pred_np >= THRESHOLD] = 1
    indexes = []

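    # per-class AUROC, specificity, sensitivity and F1 computed from the
    # collected ground truth and predictions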
    AUROCs, specificity, sensitivity, F1score = compute_AUCs(gt, pred)
    AUROC_avg = np.array(AUROCs).mean()
    print('The average AUROC is {AUROC_avg:.3f}'.format(AUROC_avg=AUROC_avg))
    for i in range(N_CLASSES):
        print('The AUROC of {} is {}'.format(CLASS_NAMES[i], AUROCs[i]))
        print('The specificity of {} is {}'.format(CLASS_NAMES[i],
                                                   specificity[i]))
        print('The sensitivity of {} is {}'.format(CLASS_NAMES[i],
                                                   sensitivity[i]))
        print('The F1 score of {} is {}'.format(CLASS_NAMES[i], F1score[i]))

    for arr in pred_np:
        indexes.append([index for index, val in enumerate(arr) if val == 1])
    for index, disease_array in enumerate(indexes):
        if len(disease_array) > 0:
            print(
                f'XRAY :: {index + 1} :: {[CLASS_NAMES[i] for i in disease_array]}'
            )
        else:
            print(f'XRAY :: {index + 1} :: No Disease detected')
        # print("Confidence of the 14 classes are {}".format(pred[i] * 100))
        print("Confidence of the 14 classes are {}".format({
            CLASS_NAMES[j]: pred[index][j] * 100
            for j in range(len(pred[index]))
        }))
        get_heat_map(model_nosyft, image_names[index], DATA_DIR)
Example #2
def main():

    cudnn.benchmark = True

    # initialize and load the model
    model = DenseNet121(N_CLASSES).cuda()
    # model = torch.nn.DataParallel(model).cuda()

    ## getting the images list to be used for heatmap generation
    image_names = []
    with open(TEST_IMAGE_LIST, 'r') as f:
        for line in f:
            items = line.split()
            image_names.append(items[0])
    ##############################################################

    if os.path.isfile(CKPT_PATH):
        print("=> loading checkpoint")
        checkpoint = torch.load(CKPT_PATH)
        pattern = re.compile(
            r'^(.*denselayer\d+\.(?:norm|relu|conv))\.((?:[12])\.(?:weight|bias|running_mean|running_var))$'
        )
        state_dict = checkpoint['state_dict']
        for key in list(state_dict.keys()):
            res = pattern.match(key)
            if res:
                new_key = res.group(1) + res.group(2)
                state_dict[new_key] = state_dict[key]
                del state_dict[key]
        for key in list(state_dict.keys()):
            split_key = key.split('.')
            new_key = '.'.join(split_key[1:])
            state_dict[new_key] = state_dict[key]
            del state_dict[key]
        model.load_state_dict(state_dict)
        print("=> loaded checkpoint")
    else:
        print("=> no checkpoint found")

    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

    test_dataset = ChestXrayDataSet(
        data_dir=DATA_DIR,
        image_list_file=TEST_IMAGE_LIST,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda(lambda crops: torch.stack(
                [transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda(
                lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ]))
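    # here the dataset is split explicitly between the bob and alice workers
    # (assumed to be PySyft VirtualWorkers defined at module level)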
    test_loader = sy.FederatedDataLoader(
        federated_dataset=test_dataset.federate((bob, alice)),
        batch_size=BATCH_SIZE,
        shuffle=False)

    # initialize the ground truth and output tensor
    gt = torch.FloatTensor().cuda()
    pred = torch.FloatTensor().cuda()

    # switch to evaluate mode
    model.eval()
    model_nosyft = model
    for i, (inp, target) in enumerate(test_loader):
        location = inp.location
        inp = inp.get()
        bs, n_crops, c, h, w = inp.size()
        # print(inp.size())
        # input_var = torch.autograd.Variable(torch.FloatTensor(inp.view(-1, c, h, w)).cuda(), volatile=True).send(location)
        input_var = inp.view(-1, c, h, w).cuda().send(location)
        model.send(location)
        output = model(input_var)
        # print(f'output of the inference is {output.shape}')
        output_mean = output.view(bs, n_crops, -1).mean(1)
        output_mean_data = output_mean.get()
        model.get()
        pred = torch.cat((pred, output_mean_data.data), 0)
    pred_np = pred.cpu().numpy()
    pred_np[pred_np < THRESHOLD] = 0
    pred_np[pred_np >= THRESHOLD] = 1
    indexes = []
    for arr in pred_np:
        indexes.append([index for index, val in enumerate(arr) if val == 1])
    for index, disease_array in enumerate(indexes):
        if len(disease_array) > 0:
            print(
                f'XRAY :: {index + 1} :: {[CLASS_NAMES[i] for i in disease_array]}'
            )
        else:
            print(f'XRAY :: {index + 1} :: No Disease detected')
        get_heat_map(model_nosyft, image_names[index], DATA_DIR)